diff --git a/index.html b/index.html index af4dce22..28832f84 100644 --- a/index.html +++ b/index.html @@ -601,11 +601,11 @@
- +
Supports Python=3.10/3.11/3.12
(tested).
Install with pip
using
The following publications utilize this software library, and refer to it as the Random Chain Motion Generator (RCMG) (more specifically the function x_xy.build_generator
):
The following publications utilize this software library, and refer to it as the Random Chain Motion Generator (RCMG) (more specifically the function ring.RCMG
):
Supports Python=3.10/3.11/3.12
(tested).
Install with pip
using
pip install 'ring @ git+https://github.com/SimiPixel/ring'
Typically, this will install jax
as cpu-only version. Afterwards, gpu-enabled version can be installed with
pip install --upgrade \"jax[cuda12_pip]\" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html\n
"},{"location":"#documentation","title":"Documentation","text":"Available here.
"},{"location":"#known-fixes","title":"Known fixes","text":""},{"location":"#offscreen-rendering-with-mujoco","title":"Offscreen rendering with Mujoco","text":"mujoco.FatalError: an OpenGL platform library has not been loaded into this process, this most likely means that a valid OpenGL context has not been created before mjr_makeContext was called
Solution:
import os\nos.environ[\"MUJOCO_GL\"] = \"egl\"\n
"},{"location":"#publications","title":"Publications","text":"The following publications utilize this software library, and refer to it as the Random Chain Motion Generator (RCMG) (more specifically the function x_xy.build_generator
):
Particularly useful is the following publication from Roy Featherstone - A Beginner\u2019s Guide to 6-D Vectors (Part 2)
"},{"location":"#contact","title":"Contact","text":"Simon Bachhuber (simon.bachhuber@fau.de)
"},{"location":"api/","title":"Api","text":""},{"location":"api/#ring.ml.ringnet.RING","title":"RING
","text":"Source code in src/ring/ml/ringnet.py
class RING(ml_base.AbstractFilter):\n def __init__(self, params=None, lam=None, jit: bool = True, name=None, **kwargs):\n self.forward_lam_factory = partial(make_ring, **kwargs)\n self.params = self._load_params(params)\n self.lam = lam\n self._name = name\n\n if jit:\n self.apply = jax.jit(self.apply, static_argnames=\"lam\")\n\n def apply(self, X, params=None, state=None, y=None, lam=None):\n if lam is None:\n assert self.lam is not None\n lam = self.lam\n\n return super().apply(X, params, state, y, tuple(lam))\n\n def init(self, bs: Optional[int] = None, X=None, lam=None, seed: int = 1):\n assert X is not None, \"Providing `X` via in `ringnet.init(X=X)` is required\"\n if bs is not None:\n assert X.ndim == 4\n\n if X.ndim == 4:\n if bs is not None:\n assert bs == X.shape[0]\n else:\n bs = X.shape[0]\n X = X[0]\n\n # (T, N, F) -> (1, N, F) for faster .init call\n X = X[0:1]\n\n if lam is None:\n assert self.lam is not None\n lam = self.lam\n\n key = jax.random.PRNGKey(seed)\n params, state = self.forward_lam_factory(lam=lam).init(key, X)\n\n if bs is not None:\n state = jax.tree_map(lambda arr: jnp.repeat(arr[None], bs, axis=0), state)\n\n return params, state\n\n def _apply_batched(self, X, params, state, y, lam):\n if (params is None and self.params is None) or state is None:\n _params, _state = self.init(bs=X.shape[0], X=X, lam=lam)\n\n if params is None and self.params is None:\n params = _params\n elif params is None:\n params = self.params\n else:\n pass\n\n if state is None:\n state = _state\n\n yhat, next_state = jax.vmap(\n self.forward_lam_factory(lam=lam).apply, in_axes=(None, 0, 0)\n )(params, state, X)\n\n return yhat, next_state\n\n @staticmethod\n def _load_params(params: str | dict | None | Path):\n assert isinstance(params, (str, dict, type(None), Path))\n if isinstance(params, (Path, str)):\n return pickle_load(params)\n return params\n\n def nojit(self) -> \"RING\":\n ringnet = RING(params=self.params, lam=self.lam, jit=False)\n ringnet.forward_lam_factory = self.forward_lam_factory\n return ringnet\n\n def _pre_save(self, params=None, lam=None) -> None:\n if params is not None:\n self.params = params\n if lam is not None:\n self.lam = lam\n\n @staticmethod\n def _post_load(ringnet: \"RING\", jit: bool = True) -> \"RING\":\n if jit:\n ringnet.apply = jax.jit(ringnet.apply, static_argnames=\"lam\")\n return ringnet\n
"},{"location":"api/#ring.ml.RING_ICML24","title":"RING_ICML24(**kwargs)
","text":"Source code in src/ring/ml/__init__.py
def RING_ICML24(**kwargs):\n from pathlib import Path\n\n params = Path(__file__).parent.joinpath(\"params/0x13e3518065c21cd8.pickle\")\n ringnet = RING(params=params, **kwargs) # noqa: F811\n ringnet = base.ScaleX_FilterWrapper(ringnet)\n ringnet = base.LPF_FilterWrapper(ringnet, 10.0, samp_freq=None)\n ringnet = base.GroundTruthHeading_FilterWrapper(ringnet)\n return ringnet\n
"},{"location":"api/#ring.base.System","title":"System
","text":"Source code in src/ring/base.py
@struct.dataclass\nclass System(_Base):\n link_parents: list[int] = struct.field(False)\n links: Link\n link_types: list[str] = struct.field(False)\n link_damping: jax.Array\n link_armature: jax.Array\n link_spring_stiffness: jax.Array\n link_spring_zeropoint: jax.Array\n # simulation timestep size\n dt: float = struct.field(False)\n # geometries in the system\n geoms: list[Geometry]\n # root / base acceleration offset\n gravity: jax.Array = struct.field(default_factory=lambda: jnp.array([0, 0, -9.81]))\n\n integration_method: str = struct.field(\n False, default_factory=lambda: \"semi_implicit_euler\"\n )\n mass_mat_iters: int = struct.field(False, default_factory=lambda: 0)\n\n link_names: list[str] = struct.field(False, default_factory=lambda: [])\n\n model_name: Optional[str] = struct.field(False, default_factory=lambda: None)\n\n omc: list[MaxCoordOMC | None] = struct.field(True, default_factory=lambda: [])\n\n def num_links(self) -> int:\n return len(self.link_parents)\n\n def q_size(self) -> int:\n return sum([Q_WIDTHS[typ] for typ in self.link_types])\n\n def qd_size(self) -> int:\n return sum([QD_WIDTHS[typ] for typ in self.link_types])\n\n def name_to_idx(self, name: str) -> int:\n return self.link_names.index(name)\n\n def idx_to_name(self, idx: int, allow_world: bool = False) -> str:\n if allow_world and idx == -1:\n return \"world\"\n assert idx >= 0, \"Worldbody index has no name.\"\n return self.link_names[idx]\n\n def idx_map(self, type: str) -> dict:\n \"type: is either `l` or `q` or `d`\"\n dict_int_slices = {}\n\n def f(_, idx_map, name: str, link_idx: int):\n dict_int_slices[name] = idx_map[type](link_idx)\n\n self.scan(f, \"ll\", self.link_names, list(range(self.num_links())))\n\n return dict_int_slices\n\n def parent_name(self, name: str) -> str:\n return self.idx_to_name(self.link_parents[self.name_to_idx(name)])\n\n def add_prefix(self, prefix: str = \"\") -> \"System\":\n return self.replace(link_names=[prefix + name for name in self.link_names])\n\n def change_model_name(\n self,\n new_name: Optional[str] = None,\n prefix: Optional[str] = None,\n suffix: Optional[str] = None,\n ) -> \"System\":\n if prefix is None:\n prefix = \"\"\n if suffix is None:\n suffix = \"\"\n if new_name is None:\n new_name = self.model_name\n name = prefix + new_name + suffix\n return self.replace(model_name=name)\n\n def change_link_name(self, old_name: str, new_name: str) -> \"System\":\n old_idx = self.name_to_idx(old_name)\n new_link_names = self.link_names.copy()\n new_link_names[old_idx] = new_name\n return self.replace(link_names=new_link_names)\n\n def add_prefix_suffix(\n self, prefix: Optional[str] = None, suffix: Optional[str] = None\n ) -> \"System\":\n if prefix is None:\n prefix = \"\"\n if suffix is None:\n suffix = \"\"\n new_link_names = [prefix + name + suffix for name in self.link_names]\n return self.replace(link_names=new_link_names)\n\n @staticmethod\n def deep_equal(a, b):\n if type(a) is not type(b):\n return False\n if isinstance(a, _Base):\n return System.deep_equal(a.__dict__, b.__dict__)\n if isinstance(a, dict):\n if a.keys() != b.keys():\n return False\n return all(System.deep_equal(a[k], b[k]) for k in a.keys())\n if isinstance(a, (list, tuple)):\n if len(a) != len(b):\n return False\n return all(System.deep_equal(a[i], b[i]) for i in range(len(a)))\n if isinstance(a, (np.ndarray, jnp.ndarray, jax.Array)):\n return jnp.array_equal(a, b)\n return a == b\n\n def _replace_free_with_cor(self) -> \"System\":\n # check that\n # - all free joints connect to -1\n # - all joints connecting to -1 are free joints\n for i, p in enumerate(self.link_parents):\n link_type = self.link_types[i]\n if (p == -1 and link_type != \"free\") or (link_type == \"free\" and p != -1):\n raise InvalidSystemError(\n f\"link={self.idx_to_name(i)}, parent=\"\n f\"{self.idx_to_name(p, allow_world=True)},\"\n f\" joint={link_type}. Hint: Try setting `config.cor` to false.\"\n )\n\n def logic_replace_free_with_cor(name, olt, ola, old, ols, olz):\n # by default new is equal to old\n nlt, nla, nld, nls, nlz = olt, ola, old, ols, olz\n\n # old link type == free\n if olt == \"free\":\n # cor joint is (free, p3d) stacked\n nlt = \"cor\"\n # entries of old armature are 3*ang (spherical), 3*pos (p3d)\n nla = jnp.concatenate((ola, ola[3:]))\n nld = jnp.concatenate((old, old[3:]))\n nls = jnp.concatenate((ols, ols[3:]))\n nlz = jnp.concatenate((olz, olz[4:]))\n\n return nlt, nla, nld, nls, nlz\n\n return _update_sys_if_replace_joint_type(self, logic_replace_free_with_cor)\n\n def freeze(self, name: str | list[str]):\n if isinstance(name, list):\n sys = self\n for n in name:\n sys = sys.freeze(n)\n return sys\n\n def logic_freeze(link_name, olt, ola, old, ols, olz):\n nlt, nla, nld, nls, nlz = olt, ola, old, ols, olz\n\n if link_name == name:\n nlt = \"frozen\"\n nla = nld = nls = nlz = jnp.array([])\n\n return nlt, nla, nld, nls, nlz\n\n return _update_sys_if_replace_joint_type(self, logic_freeze)\n\n def unfreeze(self, name: str, new_joint_type: str):\n assert self.link_types[self.name_to_idx(name)] == \"frozen\"\n assert new_joint_type != \"frozen\"\n\n return self.change_joint_type(name, new_joint_type)\n\n def change_joint_type(\n self,\n name: str,\n new_joint_type: str,\n new_arma: Optional[jax.Array] = None,\n new_damp: Optional[jax.Array] = None,\n new_stif: Optional[jax.Array] = None,\n new_zero: Optional[jax.Array] = None,\n ):\n \"By default damping, stiffness are set to zero.\"\n q_size, qd_size = Q_WIDTHS[new_joint_type], QD_WIDTHS[new_joint_type]\n\n def logic_unfreeze_to_spherical(link_name, olt, ola, old, ols, olz):\n nlt, nla, nld, nls, nlz = olt, ola, old, ols, olz\n\n if link_name == name:\n nlt = new_joint_type\n q_zeros = jnp.zeros((q_size))\n qd_zeros = jnp.zeros((qd_size,))\n\n nla = qd_zeros if new_arma is None else new_arma\n nld = qd_zeros if new_damp is None else new_damp\n nls = qd_zeros if new_stif is None else new_stif\n nlz = q_zeros if new_zero is None else new_zero\n\n # unit quaternion\n if new_joint_type in [\"spherical\", \"free\", \"cor\"] and new_zero is None:\n nlz = nlz.at[0].set(1.0)\n\n return nlt, nla, nld, nls, nlz\n\n return _update_sys_if_replace_joint_type(self, logic_unfreeze_to_spherical)\n\n def findall_imus(self) -> list[str]:\n return [name for name in self.link_names if name[:3] == \"imu\"]\n\n def findall_segments(self) -> list[str]:\n imus = self.findall_imus()\n return [name for name in self.link_names if name not in imus]\n\n def _bodies_indices_to_bodies_name(self, bodies: list[int]) -> list[str]:\n return [self.idx_to_name(i) for i in bodies]\n\n def findall_bodies_to_world(self, names: bool = False) -> list[int] | list[str]:\n bodies = [i for i, p in enumerate(self.link_parents) if p == -1]\n return self._bodies_indices_to_bodies_name(bodies) if names else bodies\n\n def find_body_to_world(self, name: bool = False) -> int | str:\n bodies = self.findall_bodies_to_world(names=name)\n assert len(bodies) == 1\n return bodies[0]\n\n def findall_bodies_with_jointtype(\n self, typ: str, names: bool = False\n ) -> list[int] | list[str]:\n bodies = [i for i, _typ in enumerate(self.link_types) if _typ == typ]\n return self._bodies_indices_to_bodies_name(bodies) if names else bodies\n\n def scan(self, f: Callable, in_types: str, *args, reverse: bool = False):\n \"\"\"Scan `f` along each link in system whilst carrying along state.\n\n Args:\n f (Callable[..., Y]): f(y: Y, *args) -> y\n in_types: string specifying the type of each input arg:\n 'l' is an input to be split according to link ranges\n 'q' is an input to be split according to q ranges\n 'd' is an input to be split according to qd ranges\n args: Arguments passed to `f`, and split to match the link.\n reverse (bool, optional): If `true` from leaves to root. Defaults to False.\n\n Returns:\n ys: Stacked output y of f.\n \"\"\"\n return _scan_sys(self, f, in_types, *args, reverse=reverse)\n\n def parse(self) -> \"System\":\n \"\"\"Initial setup of system. System object does not work unless it is parsed.\n Currently it does:\n - some consistency checks\n - populate the spatial inertia tensors\n - check that all names are unique\n - check that names are strings\n - check that all pos_min <= pos_max (unless traced)\n - order geoms in ascending order based on their parent link idx\n - check that all links have the correct size of\n - damping\n - armature\n - stiffness\n - zeropoint\n - check that n_links == len(sys.omc)\n \"\"\"\n return _parse_system(self)\n\n def render(\n self,\n xs: Optional[Transform | list[Transform]] = None,\n camera: Optional[str] = None,\n show_pbar: bool = True,\n backend: str = \"mujoco\",\n render_every_nth: int = 1,\n **scene_kwargs,\n ) -> list[np.ndarray]:\n \"\"\"Render frames from system and trajectory of maximal coordinates `xs`.\n\n Args:\n sys (base.System): System to render.\n xs (base.Transform | list[base.Transform]): Single or time-series\n of maximal coordinates `xs`.\n show_pbar (bool, optional): Whether or not to show a progress bar.\n Defaults to True.\n\n Returns:\n list[np.ndarray]: Stacked rendered frames. Length == len(xs).\n \"\"\"\n return ring.rendering.render(\n self, xs, camera, show_pbar, backend, render_every_nth, **scene_kwargs\n )\n\n def render_prediction(\n self,\n xs: Transform | list[Transform],\n yhat: dict,\n stepframe: int = 1,\n # by default we don't predict the global rotation\n transparent_segment_to_root: bool = True,\n **kwargs,\n ):\n \"`xs` matches `sys`. `yhat` matches `sys_noimu`. `yhat` are child-to-parent.\"\n return ring.rendering.render_prediction(\n self, xs, yhat, stepframe, transparent_segment_to_root, **kwargs\n )\n\n def delete_system(self, link_name: str | list[str], strict: bool = True):\n \"Cut subsystem starting at `link_name` (inclusive) from tree.\"\n return ring.sys_composer.delete_subsystem(self, link_name, strict)\n\n def make_sys_noimu(self, imu_link_names: Optional[list[str]] = None):\n \"Returns, e.g., imu_attachment = {'imu1': 'seg1', 'imu2': 'seg3'}\"\n return ring.sys_composer.make_sys_noimu(self, imu_link_names)\n\n def inject_system(self, other_system: \"System\", at_body: Optional[str] = None):\n \"\"\"Combine two systems into one.\n\n Args:\n sys (base.System): Large system.\n sub_sys (base.System): Small system that will be included into the\n large system `sys`.\n at_body (Optional[str], optional): Into which body of the large system\n small system will be included. Defaults to `worldbody`.\n\n Returns:\n base.System: _description_\n \"\"\"\n return ring.sys_composer.inject_system(self, other_system, at_body)\n\n def morph_system(\n self,\n new_parents: Optional[list[int | str]] = None,\n new_anchor: Optional[int | str] = None,\n ):\n \"\"\"Re-orders the graph underlying the system. Returns a new system.\n\n Args:\n sys (base.System): System to be modified.\n new_parents (list[int]): Let the i-th entry have value j. Then, after\n morphing the system the system will be such that the link corresponding\n to the i-th link in the old system will have as parent the link\n corresponding to the j-th link in the old system.\n\n Returns:\n base.System: Modified system.\n \"\"\"\n return ring.sys_composer.morph_system(self, new_parents, new_anchor)\n\n @staticmethod\n def from_xml(path: str, seed: int = 1):\n return ring.io.load_sys_from_xml(path, seed)\n\n @staticmethod\n def from_str(xml: str, seed: int = 1):\n return ring.io.load_sys_from_str(xml, seed)\n\n def to_str(self) -> str:\n return ring.io.save_sys_to_str(self)\n\n def to_xml(self, path: str) -> None:\n ring.io.save_sys_to_xml(self, path)\n\n @classmethod\n def create(cls, path_or_str: str, seed: int = 1) -> \"System\":\n path = Path(path_or_str).with_suffix(\".xml\")\n if path.exists():\n return cls.from_xml(path, seed=seed)\n else:\n return cls.from_str(path_or_str)\n\n def coordinate_vector_to_q(\n self,\n q: jax.Array,\n custom_joints: dict[str, Callable] = {},\n ) -> jax.Array:\n \"\"\"Map a coordinate vector `q` to the minimal coordinates vector of the sys\"\"\"\n # Does, e.g.\n # - normalize quaternions\n # - hinge joints in [-pi, pi]\n q_preproc = []\n\n def preprocess(_, __, link_type, q):\n to_q = ring.algorithms.jcalc.get_joint_model(\n link_type\n ).coordinate_vector_to_q\n # function in custom_joints has priority over JointModel\n if link_type in custom_joints:\n to_q = custom_joints[link_type]\n if to_q is None:\n raise NotImplementedError(\n f\"Please specify the custom joint `{link_type}`\"\n \" either using the `custom_joints` arguments or using the\"\n \" JointModel.coordinate_vector_to_q field.\"\n )\n new_q = to_q(q)\n q_preproc.append(new_q)\n\n self.scan(preprocess, \"lq\", self.link_types, q)\n return jnp.concatenate(q_preproc)\n
"},{"location":"api/#ring.base.System.idx_map","title":"idx_map(type)
","text":"type: is either l
or q
or d
src/ring/base.py
def idx_map(self, type: str) -> dict:\n \"type: is either `l` or `q` or `d`\"\n dict_int_slices = {}\n\n def f(_, idx_map, name: str, link_idx: int):\n dict_int_slices[name] = idx_map[type](link_idx)\n\n self.scan(f, \"ll\", self.link_names, list(range(self.num_links())))\n\n return dict_int_slices\n
"},{"location":"api/#ring.base.System.change_joint_type","title":"change_joint_type(name, new_joint_type, new_arma=None, new_damp=None, new_stif=None, new_zero=None)
","text":"By default damping, stiffness are set to zero.
Source code insrc/ring/base.py
def change_joint_type(\n self,\n name: str,\n new_joint_type: str,\n new_arma: Optional[jax.Array] = None,\n new_damp: Optional[jax.Array] = None,\n new_stif: Optional[jax.Array] = None,\n new_zero: Optional[jax.Array] = None,\n):\n \"By default damping, stiffness are set to zero.\"\n q_size, qd_size = Q_WIDTHS[new_joint_type], QD_WIDTHS[new_joint_type]\n\n def logic_unfreeze_to_spherical(link_name, olt, ola, old, ols, olz):\n nlt, nla, nld, nls, nlz = olt, ola, old, ols, olz\n\n if link_name == name:\n nlt = new_joint_type\n q_zeros = jnp.zeros((q_size))\n qd_zeros = jnp.zeros((qd_size,))\n\n nla = qd_zeros if new_arma is None else new_arma\n nld = qd_zeros if new_damp is None else new_damp\n nls = qd_zeros if new_stif is None else new_stif\n nlz = q_zeros if new_zero is None else new_zero\n\n # unit quaternion\n if new_joint_type in [\"spherical\", \"free\", \"cor\"] and new_zero is None:\n nlz = nlz.at[0].set(1.0)\n\n return nlt, nla, nld, nls, nlz\n\n return _update_sys_if_replace_joint_type(self, logic_unfreeze_to_spherical)\n
"},{"location":"api/#ring.base.System.scan","title":"scan(f, in_types, *args, reverse=False)
","text":"Scan f
along each link in system whilst carrying along state.
Parameters:
Name Type Description Defaultf
Callable[..., Y]
f(y: Y, *args) -> y
requiredin_types
str
string specifying the type of each input arg: 'l' is an input to be split according to link ranges 'q' is an input to be split according to q ranges 'd' is an input to be split according to qd ranges
requiredargs
Arguments passed to f
, and split to match the link.
()
reverse
bool
If true
from leaves to root. Defaults to False.
False
Returns:
Name Type Descriptionys
Stacked output y of f.
Source code insrc/ring/base.py
def scan(self, f: Callable, in_types: str, *args, reverse: bool = False):\n \"\"\"Scan `f` along each link in system whilst carrying along state.\n\n Args:\n f (Callable[..., Y]): f(y: Y, *args) -> y\n in_types: string specifying the type of each input arg:\n 'l' is an input to be split according to link ranges\n 'q' is an input to be split according to q ranges\n 'd' is an input to be split according to qd ranges\n args: Arguments passed to `f`, and split to match the link.\n reverse (bool, optional): If `true` from leaves to root. Defaults to False.\n\n Returns:\n ys: Stacked output y of f.\n \"\"\"\n return _scan_sys(self, f, in_types, *args, reverse=reverse)\n
"},{"location":"api/#ring.base.System.parse","title":"parse()
","text":"Initial setup of system. System object does not work unless it is parsed. Currently it does: - some consistency checks - populate the spatial inertia tensors - check that all names are unique - check that names are strings - check that all pos_min <= pos_max (unless traced) - order geoms in ascending order based on their parent link idx - check that all links have the correct size of - damping - armature - stiffness - zeropoint - check that n_links == len(sys.omc)
Source code insrc/ring/base.py
def parse(self) -> \"System\":\n \"\"\"Initial setup of system. System object does not work unless it is parsed.\n Currently it does:\n - some consistency checks\n - populate the spatial inertia tensors\n - check that all names are unique\n - check that names are strings\n - check that all pos_min <= pos_max (unless traced)\n - order geoms in ascending order based on their parent link idx\n - check that all links have the correct size of\n - damping\n - armature\n - stiffness\n - zeropoint\n - check that n_links == len(sys.omc)\n \"\"\"\n return _parse_system(self)\n
"},{"location":"api/#ring.base.System.render","title":"render(xs=None, camera=None, show_pbar=True, backend='mujoco', render_every_nth=1, **scene_kwargs)
","text":"Render frames from system and trajectory of maximal coordinates xs
.
Parameters:
Name Type Description Defaultsys
System
System to render.
requiredxs
Transform | list[Transform]
Single or time-series
None
show_pbar
bool
Whether or not to show a progress bar.
True
Returns:
Type Descriptionlist[ndarray]
list[np.ndarray]: Stacked rendered frames. Length == len(xs).
Source code insrc/ring/base.py
def render(\n self,\n xs: Optional[Transform | list[Transform]] = None,\n camera: Optional[str] = None,\n show_pbar: bool = True,\n backend: str = \"mujoco\",\n render_every_nth: int = 1,\n **scene_kwargs,\n) -> list[np.ndarray]:\n \"\"\"Render frames from system and trajectory of maximal coordinates `xs`.\n\n Args:\n sys (base.System): System to render.\n xs (base.Transform | list[base.Transform]): Single or time-series\n of maximal coordinates `xs`.\n show_pbar (bool, optional): Whether or not to show a progress bar.\n Defaults to True.\n\n Returns:\n list[np.ndarray]: Stacked rendered frames. Length == len(xs).\n \"\"\"\n return ring.rendering.render(\n self, xs, camera, show_pbar, backend, render_every_nth, **scene_kwargs\n )\n
"},{"location":"api/#ring.base.System.render_prediction","title":"render_prediction(xs, yhat, stepframe=1, transparent_segment_to_root=True, **kwargs)
","text":"xs
matches sys
. yhat
matches sys_noimu
. yhat
are child-to-parent.
src/ring/base.py
def render_prediction(\n self,\n xs: Transform | list[Transform],\n yhat: dict,\n stepframe: int = 1,\n # by default we don't predict the global rotation\n transparent_segment_to_root: bool = True,\n **kwargs,\n):\n \"`xs` matches `sys`. `yhat` matches `sys_noimu`. `yhat` are child-to-parent.\"\n return ring.rendering.render_prediction(\n self, xs, yhat, stepframe, transparent_segment_to_root, **kwargs\n )\n
"},{"location":"api/#ring.base.System.delete_system","title":"delete_system(link_name, strict=True)
","text":"Cut subsystem starting at link_name
(inclusive) from tree.
src/ring/base.py
def delete_system(self, link_name: str | list[str], strict: bool = True):\n \"Cut subsystem starting at `link_name` (inclusive) from tree.\"\n return ring.sys_composer.delete_subsystem(self, link_name, strict)\n
"},{"location":"api/#ring.base.System.make_sys_noimu","title":"make_sys_noimu(imu_link_names=None)
","text":"Returns, e.g., imu_attachment = {'imu1': 'seg1', 'imu2': 'seg3'}
Source code insrc/ring/base.py
def make_sys_noimu(self, imu_link_names: Optional[list[str]] = None):\n \"Returns, e.g., imu_attachment = {'imu1': 'seg1', 'imu2': 'seg3'}\"\n return ring.sys_composer.make_sys_noimu(self, imu_link_names)\n
"},{"location":"api/#ring.base.System.inject_system","title":"inject_system(other_system, at_body=None)
","text":"Combine two systems into one.
Parameters:
Name Type Description Defaultsys
System
Large system.
requiredsub_sys
System
Small system that will be included into the large system sys
.
at_body
Optional[str]
Into which body of the large system small system will be included. Defaults to worldbody
.
None
Returns:
Type Descriptionbase.System: description
Source code insrc/ring/base.py
def inject_system(self, other_system: \"System\", at_body: Optional[str] = None):\n \"\"\"Combine two systems into one.\n\n Args:\n sys (base.System): Large system.\n sub_sys (base.System): Small system that will be included into the\n large system `sys`.\n at_body (Optional[str], optional): Into which body of the large system\n small system will be included. Defaults to `worldbody`.\n\n Returns:\n base.System: _description_\n \"\"\"\n return ring.sys_composer.inject_system(self, other_system, at_body)\n
"},{"location":"api/#ring.base.System.morph_system","title":"morph_system(new_parents=None, new_anchor=None)
","text":"Re-orders the graph underlying the system. Returns a new system.
Parameters:
Name Type Description Defaultsys
System
System to be modified.
requirednew_parents
list[int]
Let the i-th entry have value j. Then, after morphing the system the system will be such that the link corresponding to the i-th link in the old system will have as parent the link corresponding to the j-th link in the old system.
None
Returns:
Type Descriptionbase.System: Modified system.
Source code insrc/ring/base.py
def morph_system(\n self,\n new_parents: Optional[list[int | str]] = None,\n new_anchor: Optional[int | str] = None,\n):\n \"\"\"Re-orders the graph underlying the system. Returns a new system.\n\n Args:\n sys (base.System): System to be modified.\n new_parents (list[int]): Let the i-th entry have value j. Then, after\n morphing the system the system will be such that the link corresponding\n to the i-th link in the old system will have as parent the link\n corresponding to the j-th link in the old system.\n\n Returns:\n base.System: Modified system.\n \"\"\"\n return ring.sys_composer.morph_system(self, new_parents, new_anchor)\n
"},{"location":"api/#ring.base.System.coordinate_vector_to_q","title":"coordinate_vector_to_q(q, custom_joints={})
","text":"Map a coordinate vector q
to the minimal coordinates vector of the sys
src/ring/base.py
def coordinate_vector_to_q(\n self,\n q: jax.Array,\n custom_joints: dict[str, Callable] = {},\n) -> jax.Array:\n \"\"\"Map a coordinate vector `q` to the minimal coordinates vector of the sys\"\"\"\n # Does, e.g.\n # - normalize quaternions\n # - hinge joints in [-pi, pi]\n q_preproc = []\n\n def preprocess(_, __, link_type, q):\n to_q = ring.algorithms.jcalc.get_joint_model(\n link_type\n ).coordinate_vector_to_q\n # function in custom_joints has priority over JointModel\n if link_type in custom_joints:\n to_q = custom_joints[link_type]\n if to_q is None:\n raise NotImplementedError(\n f\"Please specify the custom joint `{link_type}`\"\n \" either using the `custom_joints` arguments or using the\"\n \" JointModel.coordinate_vector_to_q field.\"\n )\n new_q = to_q(q)\n q_preproc.append(new_q)\n\n self.scan(preprocess, \"lq\", self.link_types, q)\n return jnp.concatenate(q_preproc)\n
"},{"location":"api/#ring.base.State","title":"State
","text":"The static and dynamic state of a system in minimal and maximal coordinates. Use .create()
to create this object.
Parameters:
Name Type Description Defaultq
Array
System state in minimal coordinates (equals sys.q_size()
)
qd
Array
System velocity in minimal coordinates (equals sys.qd_size()
)
x
(Transform): Maximal coordinates of all links. From epsilon-to-link.
requiredmass_mat_inv
Array
Inverse of the mass matrix. Internal usage.
required Source code insrc/ring/base.py
@struct.dataclass\nclass State(_Base):\n \"\"\"The static and dynamic state of a system in minimal and maximal coordinates.\n Use `.create()` to create this object.\n\n Args:\n q (jax.Array): System state in minimal coordinates (equals `sys.q_size()`)\n qd (jax.Array): System velocity in minimal coordinates (equals `sys.qd_size()`)\n x: (Transform): Maximal coordinates of all links. From epsilon-to-link.\n mass_mat_inv (jax.Array): Inverse of the mass matrix. Internal usage.\n \"\"\"\n\n q: jax.Array\n qd: jax.Array\n x: Transform\n mass_mat_inv: jax.Array\n\n @classmethod\n def create(\n cls,\n sys: System,\n q: Optional[jax.Array] = None,\n qd: Optional[jax.Array] = None,\n x: Optional[Transform] = None,\n key: Optional[jax.Array] = None,\n custom_joints: dict[str, Callable] = {},\n ):\n \"\"\"Create state of system.\n\n Args:\n sys (System): The system for which to create a state.\n q (jax.Array, optional): The joint values of the system. Defaults to None.\n Which then defaults to zeros.\n qd (jax.Array, optional): The joint velocities of the system.\n Defaults to None. Which then defaults to zeros.\n\n Returns:\n (State): Create State object.\n \"\"\"\n if key is not None:\n assert q is None\n q = jax.random.normal(key, shape=(sys.q_size(),))\n q = sys.coordinate_vector_to_q(q, custom_joints)\n elif q is None:\n q = jnp.zeros((sys.q_size(),))\n\n # free, cor, spherical joints are not zeros but have unit quaternions\n def replace_by_unit_quat(_, idx_map, link_typ, link_idx):\n nonlocal q\n\n if link_typ in [\"free\", \"cor\", \"spherical\"]:\n q_idxs_link = idx_map[\"q\"](link_idx)\n q = q.at[q_idxs_link.start].set(1.0)\n\n sys.scan(\n replace_by_unit_quat,\n \"ll\",\n sys.link_types,\n list(range(sys.num_links())),\n )\n else:\n pass\n\n if qd is None:\n qd = jnp.zeros((sys.qd_size(),))\n\n if x is None:\n x = Transform.zero((sys.num_links(),))\n\n return cls(q, qd, x, jnp.diag(jnp.ones((sys.qd_size(),))))\n
"},{"location":"api/#ring.base.State.create","title":"create(sys, q=None, qd=None, x=None, key=None, custom_joints={})
classmethod
","text":"Create state of system.
Parameters:
Name Type Description Defaultsys
System
The system for which to create a state.
requiredq
Array
The joint values of the system. Defaults to None.
None
qd
Array
The joint velocities of the system.
None
Returns:
Type DescriptionState
Create State object.
Source code insrc/ring/base.py
@classmethod\ndef create(\n cls,\n sys: System,\n q: Optional[jax.Array] = None,\n qd: Optional[jax.Array] = None,\n x: Optional[Transform] = None,\n key: Optional[jax.Array] = None,\n custom_joints: dict[str, Callable] = {},\n):\n \"\"\"Create state of system.\n\n Args:\n sys (System): The system for which to create a state.\n q (jax.Array, optional): The joint values of the system. Defaults to None.\n Which then defaults to zeros.\n qd (jax.Array, optional): The joint velocities of the system.\n Defaults to None. Which then defaults to zeros.\n\n Returns:\n (State): Create State object.\n \"\"\"\n if key is not None:\n assert q is None\n q = jax.random.normal(key, shape=(sys.q_size(),))\n q = sys.coordinate_vector_to_q(q, custom_joints)\n elif q is None:\n q = jnp.zeros((sys.q_size(),))\n\n # free, cor, spherical joints are not zeros but have unit quaternions\n def replace_by_unit_quat(_, idx_map, link_typ, link_idx):\n nonlocal q\n\n if link_typ in [\"free\", \"cor\", \"spherical\"]:\n q_idxs_link = idx_map[\"q\"](link_idx)\n q = q.at[q_idxs_link.start].set(1.0)\n\n sys.scan(\n replace_by_unit_quat,\n \"ll\",\n sys.link_types,\n list(range(sys.num_links())),\n )\n else:\n pass\n\n if qd is None:\n qd = jnp.zeros((sys.qd_size(),))\n\n if x is None:\n x = Transform.zero((sys.num_links(),))\n\n return cls(q, qd, x, jnp.diag(jnp.ones((sys.qd_size(),))))\n
"},{"location":"api/#ring.algorithms.dynamics.step","title":"step(sys, state, taus=None, n_substeps=1)
","text":"Source code in src/ring/algorithms/dynamics.py
def step(\n sys: base.System,\n state: base.State,\n taus: Optional[jax.Array] = None,\n n_substeps: int = 1,\n) -> base.State:\n assert sys.q_size() == state.q.size\n if taus is None:\n taus = jnp.zeros_like(state.qd)\n assert sys.qd_size() == state.qd.size == taus.size\n assert (\n sys.integration_method.lower() == \"semi_implicit_euler\"\n ), \"Currently, nothing else then `semi_implicit_euler` implemented.\"\n\n sys = sys.replace(dt=sys.dt / n_substeps)\n\n for _ in range(n_substeps):\n # update kinematics before stepping; this means that the `x` in `state`\n # will lag one step behind but otherwise we would have to return\n # the system object which would be awkward\n sys, state = kinematics.forward_kinematics(sys, state)\n state = _integration_methods[sys.integration_method.lower()](sys, state, taus)\n\n return state\n
"},{"location":"api/#ring.base.Transform","title":"Transform
","text":"Represents the Transformation from Pl\u00fccker A to Pl\u00fccker B, where B is located relative to A at pos
in frame A and rot
is the relative quaternion from A to B.
src/ring/base.py
@struct.dataclass\nclass Transform(_Base):\n \"\"\"Represents the Transformation from Pl\u00fccker A to Pl\u00fccker B,\n where B is located relative to A at `pos` in frame A and `rot` is the\n relative quaternion from A to B.\"\"\"\n\n pos: Vector\n rot: Quaternion\n\n @classmethod\n def create(cls, pos=None, rot=None):\n assert not (pos is None and rot is None), \"One must be given.\"\n shape_rot = rot.shape[:-1] if rot is not None else ()\n shape_pos = pos.shape[:-1] if pos is not None else ()\n\n if pos is None:\n pos = jnp.zeros(shape_rot + (3,))\n if rot is None:\n rot = jnp.array([1.0, 0, 0, 0])\n rot = jnp.tile(jnp.array([1.0, 0.0, 0.0, 0.0]), shape_pos + (1,))\n\n assert pos.shape[:-1] == rot.shape[:-1]\n\n return Transform(pos, rot)\n\n @classmethod\n def zero(cls, shape=()) -> \"Transform\":\n \"\"\"Returns a zero transform with a batch shape.\"\"\"\n pos = jnp.zeros(shape + (3,))\n rot = jnp.tile(jnp.array([1.0, 0.0, 0.0, 0.0]), shape + (1,))\n return Transform(pos, rot)\n\n def as_matrix(self) -> jax.Array:\n E = maths.quat_to_3x3(self.rot)\n return spatial.quadrants(aa=E, bb=E) @ spatial.xlt(self.pos)\n
"},{"location":"api/#ring.base.Transform.zero","title":"zero(shape=())
classmethod
","text":"Returns a zero transform with a batch shape.
Source code insrc/ring/base.py
@classmethod\ndef zero(cls, shape=()) -> \"Transform\":\n \"\"\"Returns a zero transform with a batch shape.\"\"\"\n pos = jnp.zeros(shape + (3,))\n rot = jnp.tile(jnp.array([1.0, 0.0, 0.0, 0.0]), shape + (1,))\n return Transform(pos, rot)\n
"},{"location":"api/#ring.algorithms.generator.base.RCMG","title":"RCMG
","text":"Source code in src/ring/algorithms/generator/base.py
class RCMG:\n def __init__(\n self,\n sys: base.System | list[base.System],\n config: jcalc.MotionConfig | list[jcalc.MotionConfig] = jcalc.MotionConfig(),\n setup_fn: Optional[types.SETUP_FN] = None,\n finalize_fn: Optional[types.FINALIZE_FN] = None,\n add_X_imus: bool = False,\n add_X_imus_kwargs: Optional[dict] = None,\n add_X_jointaxes: bool = False,\n add_X_jointaxes_kwargs: Optional[dict] = None,\n add_y_relpose: bool = False,\n add_y_rootincl: bool = False,\n sys_ml: Optional[base.System] = None,\n randomize_positions: bool = False,\n randomize_motion_artifacts: bool = False,\n randomize_joint_params: bool = False,\n randomize_anchors: bool = False,\n randomize_anchors_kwargs: Optional[dict] = None,\n randomize_hz: bool = False,\n randomize_hz_kwargs: Optional[dict] = None,\n imu_motion_artifacts: bool = False,\n imu_motion_artifacts_kwargs: Optional[dict] = None,\n dynamic_simulation: bool = False,\n dynamic_simulation_kwargs: Optional[dict] = None,\n output_transform: Optional[Callable] = None,\n keep_output_extras: bool = False,\n use_link_number_in_Xy: bool = False,\n ) -> None:\n\n randomize_anchors_kwargs = _copy_kwargs(randomize_anchors_kwargs)\n randomize_hz_kwargs = _copy_kwargs(randomize_hz_kwargs)\n\n if randomize_hz:\n finalize_fn = randomize.randomize_hz_finalize_fn_factory(finalize_fn)\n\n partial_build_gen = partial(\n _build_generator_lazy,\n setup_fn=setup_fn,\n finalize_fn=finalize_fn,\n add_X_imus=add_X_imus,\n add_X_imus_kwargs=add_X_imus_kwargs,\n add_X_jointaxes=add_X_jointaxes,\n add_X_jointaxes_kwargs=add_X_jointaxes_kwargs,\n add_y_relpose=add_y_relpose,\n add_y_rootincl=add_y_rootincl,\n randomize_positions=randomize_positions,\n randomize_motion_artifacts=randomize_motion_artifacts,\n randomize_joint_params=randomize_joint_params,\n imu_motion_artifacts=imu_motion_artifacts,\n imu_motion_artifacts_kwargs=imu_motion_artifacts_kwargs,\n dynamic_simulation=dynamic_simulation,\n dynamic_simulation_kwargs=dynamic_simulation_kwargs,\n output_transform=output_transform,\n keep_output_extras=keep_output_extras,\n use_link_number_in_Xy=use_link_number_in_Xy,\n )\n\n sys, config = utils.to_list(sys), utils.to_list(config)\n\n if randomize_anchors:\n assert (\n len(sys) == 1\n ), \"If `randomize_anchors`, then only one system is expected\"\n sys = randomize.randomize_anchors(sys[0], **randomize_anchors_kwargs)\n\n zip_sys_config = False\n if randomize_hz:\n zip_sys_config = True\n sys, config = randomize.randomize_hz(sys, config, **randomize_hz_kwargs)\n\n if sys_ml is None:\n # TODO\n if False and len(sys) > 1:\n warnings.warn(\n \"Batched simulation with multiple systems but no explicit `sys_ml`\"\n )\n sys_ml = sys[0]\n\n self.gens = []\n if zip_sys_config:\n for _sys, _config in zip(sys, config):\n self.gens.append(\n partial_build_gen(sys=_sys, config=_config, sys_ml=sys_ml)\n )\n else:\n for _sys in sys:\n for _config in config:\n self.gens.append(\n partial_build_gen(sys=_sys, config=_config, sys_ml=sys_ml)\n )\n\n def _to_data(self, sizes, seed, jit):\n return batch.batch_generators_eager_to_list(\n self.gens, sizes, seed=seed, jit=jit\n )\n\n def to_list(self, sizes: int | list[int] = 1, seed: int = 1, jit: bool = False):\n return self._to_data(sizes, seed, jit)\n\n def to_pickle(\n self,\n path: str,\n sizes: int | list[int] = 1,\n seed: int = 1,\n jit: bool = False,\n overwrite: bool = True,\n ) -> None:\n data = tree_utils.tree_batch(self._to_data(sizes, seed, jit))\n utils.pickle_save(data, path, overwrite=overwrite)\n\n def to_hdf5(\n self,\n path: str,\n sizes: int | list[int] = 1,\n seed: int = 1,\n jit: bool = False,\n overwrite: bool = True,\n ) -> None:\n data = tree_utils.tree_batch(self._to_data(sizes, seed, jit))\n utils.hdf5_save(path, data, overwrite=overwrite)\n\n def to_eager_gen(\n self,\n batchsize: int = 1,\n sizes: int | list[int] = 1,\n seed: int = 1,\n jit: bool = False,\n ) -> types.BatchedGenerator:\n return batch.batch_generators_eager(\n self.gens, sizes, batchsize, seed=seed, jit=jit\n )\n\n def to_lazy_gen(\n self, sizes: int | list[int] = 1, jit: bool = True\n ) -> types.BatchedGenerator:\n return batch.batch_generators_lazy(self.gens, sizes, jit=jit)\n\n @staticmethod\n def eager_gen_from_paths(\n paths: str | list[str],\n batchsize: int,\n include_samples: Optional[list[int]] = None,\n shuffle: bool = True,\n load_all_into_memory: bool = False,\n tree_transform=None,\n ) -> tuple[types.BatchedGenerator, int]:\n paths = utils.to_list(paths)\n return batch.batched_generator_from_paths(\n paths,\n batchsize,\n include_samples,\n shuffle,\n load_all_into_memory=load_all_into_memory,\n tree_transform=tree_transform,\n )\n
"},{"location":"api/#ring.algorithms.jcalc.MotionConfig","title":"MotionConfig
dataclass
","text":"Source code in src/ring/algorithms/jcalc.py
@dataclass\nclass MotionConfig:\n T: float = 60.0 # length of random motion\n t_min: float = 0.05 # min time between two generated angles\n t_max: float | TimeDependentFloat = 0.30 # max time ..\n\n dang_min: float | TimeDependentFloat = 0.1 # minimum angular velocity in rad/s\n dang_max: float | TimeDependentFloat = 3.0 # maximum angular velocity in rad/s\n\n # minimum angular velocity of euler angles used for `free and spherical joints`\n dang_min_free_spherical: float | TimeDependentFloat = 0.1\n dang_max_free_spherical: float | TimeDependentFloat = 3.0\n\n # max min allowed actual delta values in radians\n delta_ang_min: float | TimeDependentFloat = 0.0\n delta_ang_max: float | TimeDependentFloat = 2 * jnp.pi\n delta_ang_min_free_spherical: float | TimeDependentFloat = 0.0\n delta_ang_max_free_spherical: float | TimeDependentFloat = 2 * jnp.pi\n\n dpos_min: float | TimeDependentFloat = 0.001 # speed of translation\n dpos_max: float | TimeDependentFloat = 0.7\n pos_min: float | TimeDependentFloat = -2.5\n pos_max: float | TimeDependentFloat = +2.5\n\n # used by both `random_angle_*` and `random_pos_*`\n # only used if `randomized_interpolation` is set\n cdf_bins_min: int = 5\n # by default equal to `cdf_bins_min`\n cdf_bins_max: Optional[int] = None\n\n # flags\n randomized_interpolation_angle: bool = False\n randomized_interpolation_position: bool = False\n interpolation_method: str = \"cosine\"\n range_of_motion_hinge: bool = True\n range_of_motion_hinge_method: str = \"uniform\"\n\n # initial value of joints\n ang0_min: float = -jnp.pi\n ang0_max: float = jnp.pi\n pos0_min: float = 0.0\n pos0_max: float = 0.0\n\n # cor (center of rotation) custom fields\n cor: bool = False\n cor_t_min: float = 0.2\n cor_t_max: float | TimeDependentFloat = 2.0\n cor_dpos_min: float | TimeDependentFloat = 0.00001\n cor_dpos_max: float | TimeDependentFloat = 0.5\n cor_pos_min: float | TimeDependentFloat = -0.4\n cor_pos_max: float | TimeDependentFloat = 0.4\n\n def is_feasible(self) -> bool:\n return _is_feasible_config1(self)\n\n def to_nomotion_config(self) -> \"MotionConfig\":\n kwargs = asdict(self)\n for key in [\n \"dang_min\",\n \"dang_max\",\n \"delta_ang_min\",\n \"dang_min_free_spherical\",\n \"dang_max_free_spherical\",\n \"delta_ang_min_free_spherical\",\n \"dpos_min\",\n \"dpos_max\",\n ]:\n kwargs[key] = 0.0\n nomotion_config = MotionConfig(**kwargs)\n assert nomotion_config.is_feasible()\n return nomotion_config\n
"},{"location":"api/#ring.algorithms.jcalc.register_new_joint_type","title":"register_new_joint_type(joint_type, joint_model, q_width, qd_width=None, overwrite=False)
","text":"Source code in src/ring/algorithms/jcalc.py
def register_new_joint_type(\n joint_type: str,\n joint_model: JointModel,\n q_width: int,\n qd_width: Optional[int] = None,\n overwrite: bool = False,\n):\n # this name is used\n assert joint_type != \"default\", \"Please use another name.\"\n\n exists = joint_type in _joint_types\n if exists and overwrite:\n for dic in [\n base.Q_WIDTHS,\n base.QD_WIDTHS,\n _joint_types,\n ]:\n dic.pop(joint_type)\n else:\n assert (\n not exists\n ), f\"joint type `{joint_type}`already exists, use `overwrite=True`\"\n\n if qd_width is None:\n qd_width = q_width\n\n assert len(joint_model.motion) == qd_width\n\n _joint_types.update({joint_type: joint_model})\n base.Q_WIDTHS.update({joint_type: q_width})\n base.QD_WIDTHS.update({joint_type: qd_width})\n
"},{"location":"api/#ring.algorithms.jcalc.JointModel","title":"JointModel
dataclass
","text":"Source code in src/ring/algorithms/jcalc.py
@dataclass\nclass JointModel:\n # (q, params) -> Transform\n transform: Callable[[jax.Array, jax.Array], base.Transform]\n # len(motion) == len(qd)\n # if callable: joint_params -> base.Motion\n motion: list[base.Motion | Callable[[jax.Array], base.Motion]] = field(\n default_factory=lambda: []\n )\n # (config, key_t, key_value, params) -> jax.Array\n rcmg_draw_fn: Optional[DRAW_FN] = None\n\n # only used by `pd_control`\n p_control_term: Optional[P_CONTROL_TERM] = None\n qd_from_q: Optional[QD_FROM_Q] = None\n\n # used by\n # -`inverse_kinematics_endeffector`\n # - System.coordinate_vector_to_q\n coordinate_vector_to_q: Optional[COORDINATE_VECTOR_TO_Q] = None\n\n # only used by `inverse_kinematics`\n inv_kin: Optional[INV_KIN] = None\n\n init_joint_params: Optional[INIT_JOINT_PARAMS] = None\n\n utilities: Optional[dict[str, Any]] = field(default_factory=lambda: dict())\n
"},{"location":"api/#ring.algorithms.jcalc.join_motionconfigs","title":"join_motionconfigs(configs, boundaries)
","text":"Source code in src/ring/algorithms/jcalc.py
def join_motionconfigs(\n configs: list[MotionConfig], boundaries: list[float]\n) -> MotionConfig:\n assert len(configs) == (\n len(boundaries) + 1\n ), \"length of `boundaries` should be one less than length of `configs`\"\n boundaries = jnp.array(boundaries, dtype=float)\n\n def new_value(field: str):\n scalar_options = jnp.array([getattr(c, field) for c in configs])\n\n def scalar(t):\n return jax.lax.dynamic_index_in_dim(\n scalar_options, _find_interval(t, boundaries), keepdims=False\n )\n\n return scalar\n\n hints = get_type_hints(MotionConfig())\n attrs = MotionConfig().__dict__\n is_time_dependent_field = lambda key: hints[key] == (float | TimeDependentFloat)\n time_dependent_fields = [key for key in attrs if is_time_dependent_field(key)]\n time_independent_fields = [key for key in attrs if not is_time_dependent_field(key)]\n\n for time_dep_field in time_independent_fields:\n field_values = set([getattr(config, time_dep_field) for config in configs])\n assert (\n len(field_values) == 1\n ), f\"MotionConfig.{time_dep_field}={field_values}. Should be one unique value..\"\n\n changes = {field: new_value(field) for field in time_dependent_fields}\n return replace(configs[0], **changes)\n
"},{"location":"notebooks/batched_simulation/","title":"Batched simulation","text":"Note
This example is available as a jupyter notebook here.
System
object is a registered Jax-PyTree. This means it's a nested array.
This enables us to stack multiple systems (or states) to enable vectorized operations.
import ring\n\nimport jax\nimport jax.numpy as jnp\n\n\nxml_str = \"\"\"\n<x_xy model=\"double_pendulum\">\n<options dt=\"0.01\" gravity=\"0 0 9.81\"></options>\n<worldbody>\n<body damping=\"2\" euler=\"0 90 0\" joint=\"ry\" name=\"upper\">\n<geom dim=\"1 0.25 0.2\" mass=\"10\" pos=\"0.5 0 0\" type=\"box\"></geom>\n<body damping=\"2\" joint=\"ry\" name=\"lower\" pos=\"1 0 0\">\n<geom dim=\"1 0.25 0.2\" mass=\"10\" pos=\"0.5 0 0\" type=\"box\"></geom>\n</body>\n</body>\n</worldbody>\n</x_xy>\n\"\"\"\n\nsys = ring.System.create(xml_str)\nstate = ring.State.create(sys)\n
# second system with gravity disabled\nsys_nograv = sys.replace(gravity = sys.gravity * 0.0)\nsys_batched = sys.batch(sys_nograv)\n\nnext_state_batched = jax.vmap(ring.step, in_axes=(0, None))(sys_batched, state)\n
# note how the state of the system without gravity has not changed at all\nnext_state_batched.q\n
\nArray([[-1.7982468e-10, 2.3305433e-10],\n [ 0.0000000e+00, 0.0000000e+00]], dtype=float32)
\n
second_state = ring.State.create(sys, qd=jnp.ones((2,)))\nstate_batched = state.batch(second_state)\nnext_state_batched = jax.vmap(ring.step, in_axes=(None, 0))(sys, state_batched)\n
next_state_batched.q\n
\nArray([[-1.7982468e-10, 2.3305433e-10],\n [ 1.0048340e-02, 9.8215193e-03]], dtype=float32)
\n
Batched kinematic simulation is done by providing the sizes
argument to build_generator
batchsize = 8\nseed = 1\ngen = ring.RCMG(sys, ring.MotionConfig(T=10.0, t_max=1.5), keep_output_extras=True).to_lazy_gen(batchsize)\n(X, y), (_, q, x, _) = gen(jax.random.PRNGKey(seed))\n
q.shape\n
\n(8, 1000, 2)
\n
\n
"},{"location":"notebooks/batched_simulation/#batched-dynamical-simulation","title":"Batched Dynamical Simulation","text":""},{"location":"notebooks/batched_simulation/#batched-system","title":"Batched System","text":"I.e. simulating two different system with the same initial state.
"},{"location":"notebooks/batched_simulation/#batched-state","title":"Batched State","text":""},{"location":"notebooks/batched_simulation/#batched-kinematic-simulation","title":"Batched Kinematic Simulation","text":""},{"location":"notebooks/control/","title":"Control","text":"Note
This example is available as a jupyter notebook here.
import ring\n\nfrom ring.algorithms.generator.pd_control import _pd_control\n\nimport jax\nimport jax.numpy as jnp\nimport numpy as np\n\nimport mediapy as media\n
The step
function also takes generalized forces tau
applied to the degrees of freedom its third input step(sys, state, taus)
.
Let's consider an inverted pendulum on a cart, and apply a left-right force onto the cart such that the pole stays in the upright position.
xml_str = \"\"\"\n<x_xy model=\"inv_pendulum\">\n<options dt=\"0.01\" gravity=\"0 0 9.81\"></options>\n<defaults>\n<geom color=\"white\" edge_color=\"black\"></geom>\n</defaults>\n<worldbody>\n<body damping=\"0.01\" joint=\"px\" name=\"cart\">\n<geom dim=\"0.4 0.1 0.1\" mass=\"1\" type=\"box\"></geom>\n<body damping=\"0.01\" euler=\"0 -90 0\" joint=\"ry\" name=\"pendulum\">\n<geom dim=\"1 0.1 0.1\" mass=\"0.5\" pos=\"0.5 0 0\" type=\"box\"></geom>\n</body>\n</body>\n</worldbody>\n</x_xy>\n\"\"\"\n\nsys = ring.System.create(xml_str)\nstate = ring.State.create(sys, q=jnp.array([0.0, 0.2])) \n\nxs = []\nT = 10.0\nfor t in range(int(T / sys.dt)):\n measurement_noise = np.random.normal() * 5\n phi = jnp.rad2deg(state.q[1]) + measurement_noise\n cart_motor_input = 0.1 * phi * abs(phi)\n taus = jnp.clip(jnp.array([cart_motor_input, 0.0]), -10, 10) \n state = jax.jit(ring.step)(sys, state, taus)\n xs.append(state.x)\n
def show_video(sys, xs: list[ring.Transform]):\n assert sys.dt == 0.01\n # only render every fourth to get a framerate of 25 fps\n frames = sys.render(xs, render_every_nth=4, camera=\"c\", add_cameras={-1: '<camera mode=\"targetbody\" name=\"c\" pos=\"0 -2 2\" target=\"0\"></camera>'})\n media.show_video(frames, fps=25)\n\nshow_video(sys, xs)\n
\nRendering frames..: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 250/250 [00:01<00:00, 174.21it/s]\n
\n
This browser does not support the video tag. xml_str = \"\"\"\n<x_xy>\n<options dt=\"0.01\" gravity=\"0 0 9.81\"></options>\n<worldbody>\n<body damping=\"0.01\" euler=\"0 90 0\" joint=\"ry\" name=\"pendulum\" pos=\"0 0 1\">\n<geom dim=\"1 0.1 0.1\" mass=\"0.5\" pos=\"0.5 0 0\" type=\"box\"></geom>\n</body>\n</worldbody>\n</x_xy>\n\"\"\"\n\nsys = ring.System.create(xml_str)\nP, D = jnp.array([10.0]), jnp.array([1.0])\n\ndef simulate_pd_control(sys, P, D):\n controller = _pd_control(P, D)\n # reference signal\n q_ref = jnp.ones((1000, 1)) * jnp.pi / 2\n controller_state = controller.init(sys, q_ref)\n state = ring.State.create(sys) \n\n xs = []\n T = 5.0\n for t in range(int(T / sys.dt)):\n controller_state, taus = jax.jit(controller.apply)(controller_state, sys, state)\n state = jax.jit(ring.step)(sys, state, taus)\n xs.append(state.x)\n return xs\n
xs = simulate_pd_control(sys, P, D)\nshow_video(sys, xs)\n
\nRendering frames..: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 125/125 [00:00<00:00, 165.26it/s]\n
\n
This browser does not support the video tag. Note the steady state error. This is because we have gravity and no Integral part (so no PID control).
If we remove gravity the steady state error also vanishes (as is expected.)
sys_nograv = sys.replace(gravity = sys.gravity * 0.0)\nxs = simulate_pd_control(sys_nograv, P, D)\nshow_video(sys_nograv, xs)\n
\nRendering frames..: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 125/125 [00:00<00:00, 132.06it/s]\n
\n
This browser does not support the video tag. \n
"},{"location":"notebooks/control/#balance-an-inverted-pendulum-on-a-cart","title":"Balance an inverted Pendulum on a cart","text":""},{"location":"notebooks/control/#pd-control","title":"PD Control","text":""},{"location":"notebooks/custom_joint_type/","title":"Custom joint type","text":"Note
This example is available as a jupyter notebook here.
In this notebook we will define a new joint type that is a hinge joint with a random joint axes direction.
It will also support dynamical simulation.
import ring\nfrom ring import maths, base\n\nimport jax\nimport jax.numpy as jnp\n\nimport mediapy as media\n\nfrom ring.algorithms.jcalc import _draw_rxyz\n
We will give this new joint type the identifier rr
(random revolute). Although it actually already exists in the library, but we can overwrite it.
# we use such a `params` input to specify the joint-axes, if we later then randomize the attribute of the system object\n# we will have the effect of a hinge joint with a randomized joint axes direction\n# here we tell the library how it should initialize this `params` PyTree\ndef _draw_random_joint_axis(key):\n return maths.rotate(jnp.array([1.0, 0, 0]), maths.quat_random(key))\n\ndef _rr_init_joint_params(key):\n return dict(joint_axes=_draw_random_joint_axis(key))\n\n# next, we tell the library how it can randomly draw a trajectory for its generalized coordinate; the hinge joint angle\ndef _rr_transform(q, params):\n # here we use this `params` object\n axis = params[\"joint_axes\"]\n q = jnp.squeeze(q)\n rot = maths.quat_rot_axis(axis, q)\n return ring.Transform.create(rot=rot)\n\n# this tells the library how to dynamically simulate the type of joint\ndef _motion_fn(params):\n return base.Motion.create(ang=params[\"joint_axes\"])\n\n# now, we can put it all together into a new `x_xy.JointModel`\nrr_joint = ring.JointModel(_rr_transform, motion=[_motion_fn], rcmg_draw_fn=_draw_rxyz, init_joint_params=_rr_init_joint_params)\n\n# and then we register the joint; Note that `overwrite`=True, because it already exists; that way you can e.g. overwrite the\n# default joint types such as the free joint\nring.register_new_joint_type(\"rr\", rr_joint, q_width=1, qd_width=1, overwrite=True)\n
xml_str = \"\"\"\n<x_xy>\n<options dt=\"0.01\" gravity=\"0 0 9.81\"></options>\n<worldbody>\n<geom dim=\"0.1\" type=\"xyz\"></geom>\n<body damping=\".01\" joint=\"rr\" name=\"pendulum\" pos=\"0 0 0.5\">\n<geom dim=\"0.1\" type=\"xyz\"></geom>\n<geom dim=\"0.5 0.1 0.1\" mass=\"0.5\" pos=\"0.25 0 0\" type=\"box\"></geom>\n</body>\n</worldbody>\n</x_xy>\n\"\"\"\n\n# this seed determines (among other things) the randomness of the joint-axes direction\n# via the above specified `_rr_init_joint_params`\nseed: int = 2\nsys = ring.System.create(xml_str, seed=seed)\n
state = ring.State.create(sys)\nxs = []\nfor t in range(500):\n state = jax.jit(ring.step)(sys, state)\n xs.append(state.x)\n
sys.links.joint_params\n
\n{'rr': {'joint_axes': Array([[ 0.41278404, -0.6329913 , 0.65492845]], dtype=float32)},\n 'default': Array([], shape=(1, 0), dtype=float32)}
\n
def show_video(sys, xs: list[ring.Transform]):\n frames = sys.render(xs, render_every_nth=4)\n media.show_video(frames, fps=25)\n\nshow_video(sys, xs)\n
\nRendering frames..: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 125/125 [00:00<00:00, 158.41it/s]\n
\n
This browser does not support the video tag. the class x_xy.RCMG
already has the built-in flag randomize_joint_params
which can be toggled in order to use the user-provided logic _rr_init_joint_params
for randomizing the joint parameters
(X, y), (key, q, x, _) = ring.RCMG(sys, randomize_joint_params=True, keep_output_extras=True).to_list()[0]\n
\neager data generation: 1it [00:02, 2.24s/it]\n
\n
show_video(sys, x)\n
\nRendering frames..: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 1500/1500 [00:08<00:00, 169.89it/s]\n
\n
This browser does not support the video tag. but for dynamic_simulation
flag to work we additional need to specify the function ring.JointModel.p_control_term
print(rr_joint.p_control_term)\n
\nNone\n
\n
try:\n (X, y), (key, q, x, _) = ring.RCMG(sys, randomize_joint_params=True, keep_output_extras=True, dynamic_simulation=True).to_list()[0]\nexcept NotImplementedError:\n print(\"NotImplementedError: Please specify `JointModel.p_control_term` for joint type `rr`\")\n
\neager data generation: 0it [00:00, ?it/s]
\n
\nNotImplementedError: Please specify `JointModel.p_control_term` for joint type `rr`\n
\n
\n\n
\n
\n
"},{"location":"notebooks/custom_joint_type/#defining-a-custom-joint-type-that-supports-dynamical-simulation","title":"Defining a custom Joint Type that supports dynamical simulation","text":""},{"location":"notebooks/error_quaternion/","title":"Error quaternion","text":"Note
This example is available as a jupyter notebook here.
In this notebook we will talk about what functions you need to do ML with quaternions. After all the purpose of this library is to create training data.
Typically, this involves quaternions as target values (to be predicted), similar to an orientation estimation filter (like VQF).
So, suppose you want to train some ML model that predicts a quaternion \\(\\hat{q} = f_\\theta(X)\\).
import ring\nimport jax \nimport jax.numpy as jnp\nimport matplotlib.pyplot as plt\n
# suppose a 6D IMU input\nfeature_dim = 6\n\nparams = jax.random.normal(jax.random.PRNGKey(1), (4, feature_dim))\ndef neural_network(params, X):\n q_unnormalized = params@X\n norm = jnp.linalg.norm(q_unnormalized)\n return q_unnormalized / norm\n\n\ndef loss_fn(params, X, y):\n q, qhat = y, neural_network(params, X)\n # squared angle error\n return ring.maths.angle_error(q, qhat)**2\n
But this is dangerous as this might lead to NaNs.
X = jnp.zeros((6,))\ny = jnp.array([1.0, 0, 0, 0])\nloss_fn(params, X, y)\n
\nArray(nan, dtype=float32)
\n
We could try to fix is by adding a small number in the divison.
# suppose a 6D IMU input\nfeature_dim = 6\n\nparams = jax.random.normal(jax.random.PRNGKey(1), (4, feature_dim))\ndef neural_network(params, X):\n q_unnormalized = params@X\n norm = jnp.linalg.norm(q_unnormalized)\n eps = 1e-8\n return q_unnormalized / (norm + eps)\n\n\ndef loss_fn(params, X, y):\n q, qhat = y, neural_network(params, X)\n # squared angle error\n return ring.maths.angle_error(q, qhat)**2\n\nX = jnp.zeros((6,))\ny = jnp.array([1.0, 0, 0, 0])\nloss_fn(params, X, y)\n
\nArray(0., dtype=float32)
\n
But, still the gradient required for backpropagation gives NaNs.
jax.grad(loss_fn)(params, X, y)\n
\nArray([[nan, nan, nan, nan, nan, nan],\n [nan, nan, nan, nan, nan, nan],\n [nan, nan, nan, nan, nan, nan],\n [nan, nan, nan, nan, nan, nan]], dtype=float32)
\n
The solution is a little involved. TLDR; Use x_xy.maths.safe_normalize
# suppose a 6D IMU input\nfeature_dim = 6\n\nparams = jax.random.normal(jax.random.PRNGKey(1), (4, feature_dim))\ndef neural_network(params, X):\n q_unnormalized = params@X\n return ring.maths.safe_normalize(q_unnormalized)\n\n\ndef loss_fn(params, X, y):\n q, qhat = y, neural_network(params, X)\n # squared angle error\n return ring.maths.angle_error(q, qhat)**2\n\nX = jnp.zeros((6,))\ny = jnp.array([1.0, 0, 0, 0])\nloss_fn(params, X, y)\n
\nArray(0., dtype=float32)
\n
jax.grad(loss_fn)(params, X, y)\n
\nArray([[0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.]], dtype=float32)
\n
Let's take a closer look at the function x_xy.maths.angle_error
which was used in the loss_fn
in the above.
What is the behaviour of the error function (sort of the metric) between two quaternions as one approaches the other?
A first implementation might look like this:
def quat_error(q, qhat):\n q_error = ring.maths.quat_mul(ring.maths.quat_inv(q), qhat)\n phi = 2 * jnp.arccos(q_error[0])\n return jnp.abs(phi)\n
Let's reduce this function to the critical operation phi = ...
and let's assume, without loss of generality, that the target quaternion is the identity quaternion.
Then, this effectively becomes about extracting the angle from a quaternion safely.
def quat_angle(q):\n return 2 * jnp.arccos(q[0])\n
input_angles = jnp.linspace(-0.005, 0.005, num=1000)\n\ndef input_to_output_angles_incorrect(angle):\n q = ring.maths.quat_rot_axis(jnp.array([1.0, 0, 0]), angle)\n return quat_angle(q)\n\ndef input_to_output_angles_correct(angle):\n q = ring.maths.quat_rot_axis(jnp.array([1.0, 0, 0]), angle)\n return ring.maths.quat_angle(q)\n
plt.plot(input_angles, jax.vmap(input_to_output_angles_incorrect)(input_angles), label=\"incorrect\")\nplt.plot(input_angles, jax.vmap(input_to_output_angles_correct)(input_angles), label=\"correct\")\nplt.legend()\nplt.show()\n
As one might expect, the gradients are also much more stable.
plt.plot(input_angles, jax.vmap(jax.grad(input_to_output_angles_incorrect))(input_angles), label=\"incorrect\")\nplt.plot(input_angles, jax.vmap(jax.grad(input_to_output_angles_correct))(input_angles), label=\"correct\")\nplt.legend()\nplt.show()\n
\n
"},{"location":"notebooks/error_quaternion/#the-error-quaternion-required-for-ml-purposes","title":"The error quaternion (required for ML purposes)","text":""},{"location":"notebooks/error_quaternion/#how-to-get-a-quaternion-as-network-output","title":"How to get a quaternion as network output?","text":"That's easy enough. You normalize a four dimensional vector.
"},{"location":"notebooks/error_quaternion/#a-closer-look-at-the-function-x_xymathsangle_error","title":"A closer look at the functionx_xy.maths.angle_error
","text":""},{"location":"notebooks/error_quaternion/#pytorch-library-for-quaternion-operations","title":"Pytorch library for quaternion operations","text":"These functions are for JAX, but the following should work for PyTorch -> https://naver.github.io/roma/
"},{"location":"notebooks/experimental_data/","title":"Experimental data","text":"Note
This example is available as a jupyter notebook here.
import ring\n\nimport jax\nimport jax.numpy as jnp\nimport numpy as np\nimport matplotlib.pyplot as plt\n\nimport mediapy as media\n\ndef show_video(sys: ring.System, xs: ring.Transform) -> None:\n assert sys.dt == 0.01\n # only render every fourth to get a framerate of 25 fps\n frames = sys.render(xs, camera=\"c\", height=480, width=640, render_every_nth=4,\n add_cameras={-1: '<camera mode=\"targetbody\" name=\"c\" pos=\".5 -.5 1.25\" target=\"3\"></camera>'})\n media.show_video(frames, fps=25)\n
Experimental data and system definitions of the experimental setup are located in..
from ring import exp\n
Multiple experimental trials are available. They have exp_id
s and motion_start
s and motion_stop
s
exp_id = \"S_06\"\nsys = exp.load_sys(exp_id)\n
Let's first take a look at the system that was used in the experiments.
state = ring.State.create(sys)\n# update the maximal coordinates\nxs = ring.algorithms.forward_kinematics(sys, state)[1].x\n
show_video(sys, xs)\n
\nRendering frames..: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 1/1 [00:00<00:00, 7.88it/s]\n
\n
This browser does not support the video tag. As you can see a five segment kinematic chain was moved, and for each segment IMU measurements and OMC ground truth is available.
Let's load this (no simulated) IMU and OMC data.
# `canonical` is the identifier of the first motion pattern performed in this trial\n# `shaking` is the identifier of the last motion pattern performed in this trial\nmotion_start = \"canonical\"\ndata = exp.load_data(exp_id, motion_start=motion_start)\n
data.keys()\n
\ndict_keys(['seg1', 'seg2', 'seg3', 'seg4', 'seg5'])
\n
data[\"seg1\"].keys()\n
\ndict_keys(['imu_flex', 'imu_rigid', 'marker1', 'marker2', 'marker3', 'marker4', 'quat'])
\n
data[\"seg1\"][\"imu_rigid\"].keys()\n
\ndict_keys(['acc', 'gyr', 'mag'])
\n
The quaternion quat
is to be interpreted as the rotation from segment to an arbitrary OMC inertial frame.
The position marker1
is to be interpreted as the position vector from arbitrary OMC inertial frame to a specific marker (marker 1) on the respective segment (vector given in the OMC inertial frame).
Then, for each segment actually two IMUs are attached to it. One is rigidly attached, one is non-rigidly attached (via foam).
Also, how long is the trial?
data[\"seg1\"][\"marker1\"].shape\n
\n(14200, 3)
\n
It's 325 seconds of data.
Let's take a look at the motion of the whole trial.
To render it, we need maximal coordinates xs
of all links in the system.
X, y, xs, xs_noimu = exp.benchmark_fn(exp.IMTP(segments=sys.findall_segments()), exp_id, motion_start)\n
show_video(sys, xs)\n
\nRendering frames..: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 3550/3550 [00:30<00:00, 118.12it/s]\n
\n
This browser does not support the video tag. Perfect. This is a rendered animation of the real experimental motion that was performed. You can see that the spacing between segments is not perfect.
This is because in our idealistic system model joints have no spatial dimension but in reality they have. The entire setup is 3D printed, and the joints are also several centimeters long.
The segments are 20cm long.
We can use this experimental data to validate our simulated approaches or validate ML models that are learned on simulated training data.
\n
"},{"location":"notebooks/experimental_data/#loading-and-working-with-experimental-data","title":"Loading and working with experimental data","text":""},{"location":"notebooks/getting_started/","title":"Getting started","text":"Note
This example is available as a jupyter notebook here.
And on Google Colab here
Setup the environment if this is executed on Google Colab.
Make sure to change the runtime type to GPU
. To do this go to Runtime
-> Change runtime type
-> GPU
Otherwise, rendering won't work in Google Colab.
import os\n\ntry:\n import google.colab\n IN_COLAB = True\nexcept:\n IN_COLAB = False\n\nif IN_COLAB:\n os.system(\"pip install --quiet 'ring @ git+https://github.com/SimiPixel/ring'\")\n os.system(\"pip install --quiet mediapy\")\n os.system(\"pip install --quiet matplotlib\")\n
import ring\n# automatically detects colab or not\nring.utils.setup_colab_env()\n\nimport jax\nimport jax.numpy as jnp\nimport numpy as np\nimport matplotlib.pyplot as plt\n\nimport mediapy as media\n
Systems are defined with the following xml syntax.
xml_str = \"\"\"\n<x_xy model=\"double_pendulum\">\n<options dt=\"0.01\" gravity=\"0 0 9.81\"></options>\n<worldbody>\n<body damping=\"2\" euler=\"0 90 0\" joint=\"ry\" name=\"upper\" pos=\"0 0 2\">\n<geom dim=\"1 0.25 0.2\" mass=\"10\" pos=\"0.5 0 0\" type=\"box\"></geom>\n<body damping=\"2\" joint=\"ry\" name=\"lower\" pos=\"1 0 0\">\n<geom dim=\"1 0.25 0.2\" mass=\"10\" pos=\"0.5 0 0\" type=\"box\"></geom>\n</body>\n</body>\n</worldbody>\n</x_xy>\n\"\"\"\n
With this xml description of the system, we are ready to load the system using load_sys_from_str
. We can also save this to a text-file double_pendulum.xml
and load with load_sys_from_xml
.
sys = ring.System.create(xml_str)\n
sys.model_name\n
\n'double_pendulum'
\n
System objects have many attributes. You may refer to the API documentation for more details.
sys.link_names\n
\n['upper', 'lower']
\n
Let's start with the most obvious. A physical simulation. We refer to it as \"dynamical simulation\", in contrast to what we do a little later which is a purely kinematic simulation.
First, we have to create the dynamical state of the system. It is defined by the all degrees of freedom in the system and their velocities. Here, we have two revolute joints (one degree of freedom). Thus, the minimal coordinates vector \\(q\\) and minimal velocity vector \\(q'\\) has two dimensions.
state = ring.State.create(sys)\n
state.q\n
\nArray([0., 0.], dtype=float32)
\n
state.qd\n
\nArray([0., 0.], dtype=float32)
\n
next_state = ring.step(sys, state)\n
Massive speedups if we use jax.jit
to jit-compile the function.
%timeit ring.step(sys, state)\n
\n196 ms \u00b1 5.62 ms per loop (mean \u00b1 std. dev. of 7 runs, 1 loop each)\n
\n
%timeit jax.jit(ring.step)(sys, state)\n
\n90.2 \u00b5s \u00b1 41.4 \u00b5s per loop (mean \u00b1 std. dev. of 7 runs, 1 loop each)\n
\n
Let's unroll the dynamics for multiple timesteps.
T = 10.0\nxs = []\nfor _ in range(int(T / sys.dt)):\n state = jax.jit(ring.step)(sys, state)\n xs.append(state.x)\n
Next, let's render the frames and create an animation.
frames = sys.render(xs, camera=\"targetfar\")\n
\nRendering frames..: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 1000/1000 [00:05<00:00, 176.53it/s]\n
\n
def show_video(frames: list[np.ndarray], dt: float):\n assert dt == 0.01\n # frames are at 100 Hz, but let's create an animation at 25Hz\n media.show_video([frames[i][..., :3] for i in range(0, len(frames), 4)], fps=25)\n\nshow_video(frames, sys.dt)\n
This browser does not support the video tag. Hmm, pretty boring. Let's get the pendulum into an configuration with some potential energy.
All we have to change is the initial state state.q
.
state = ring.State.create(sys, q=jnp.array([jnp.pi / 2, 0]))\n
T = 10.0\nxs = []\nfor _ in range(int(T / sys.dt)):\n state = jax.jit(ring.step)(sys, state)\n xs.append(state.x)\n\nframes = sys.render(xs, camera=\"targetfar\")\nshow_video(frames, sys.dt)\n
\nRendering frames..: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 1000/1000 [00:05<00:00, 177.75it/s]\n
\n
This browser does not support the video tag. That's more like it!
Next, we will take a look at \"kinematic simulation\".
Let's start with why you would want this.
Imagine we want to learn a filter that estimates some quantity of interest from some sensor input.
Then, we could try to create many random motions, record the measured sensor input, and the ground truth quantity of interest target values.
This is then used as training data for a Machine Learning model.
The general interface to kinematic simulation is via x_xy.RCMG
.
This class can then create - a function (of type Generator
) that maps a PRNG seed to, e.g., X, y
data. - a list of data - data on disk (saved via pickle or hdf5)
(X, y), (key, q, xs, _) = ring.RCMG(sys, ring.MotionConfig(T=10.0, t_max=1.5), keep_output_extras=True).to_list()[0]\n
\neager data generation: 1it [00:01, 1.95s/it]\n
\n
frames = sys.render(xs, camera=\"targetfar\")\n
\nRendering frames..: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 1000/1000 [00:05<00:00, 177.14it/s]\n
\n
This is now completely random, but unphysical motion. It's only kinematics, but that is okay for creating training data.
show_video(frames, sys.dt)\n
This browser does not support the video tag. We are interested in simulating IMU data as input X
, and estimating quaternions as target y
.
We can easily simulate an IMU with only the trajectory of maximal coordinates xs
.
Suppose, we want to simulate an IMU right that is placed on the lower
segment and right at the revolute joint.
This is exactly where the coordinate system of the lower
segment is placed.
Right now the xs
trajectory contains both coordinate sytems of upper
and lower
.
# (n_timesteps, n_links, 3)\nxs.pos.shape\n
\n(1000, 2, 3)
\n
# (n_timesteps, n_links, 4)\nxs.rot.shape\n
\n(1000, 2, 4)
\n
From the axis with length two, the 0th entry is for upper
and the 1st entry is for lower
.
sys.name_to_idx(\"upper\")\n
\n0
\n
sys.name_to_idx(\"lower\")\n
\n1
\n
xs_lower = xs.take(1, axis=1)\n
imu_lower = ring.algorithms.imu(xs_lower, sys.gravity, sys.dt)\n
imu_lower.keys()\n
\ndict_keys(['acc', 'gyr'])
\n
plt.grid()\nplt.plot(np.arange(0, 10.0, step=sys.dt), imu_lower[\"gyr\"], label=[\"x\", \"y\", \"z\"])\nplt.ylabel(\"gyro [rad / s]\")\nplt.xlabel(\"time [s]\")\nplt.legend()\nplt.show()\n
As you can see it's a two-dimensional problem, which is why only one (y
) is non-zero.
Let's consider a larger kinematic chain in free 3D space.
xml_str = \"\"\"\n<x_xy model=\"three_segment_kinematic_chain\">\n<options dt=\"0.01\" gravity=\"0 0 9.81\"></options>\n<worldbody>\n<body joint=\"free\" name=\"seg2\" pos=\"0 0 2\">\n<geom dim=\"1 0.25 0.2\" mass=\"0.1\" pos=\"0.5 0 0\" type=\"box\"></geom>\n<body joint=\"ry\" name=\"seg1\">\n<geom dim=\"1 0.25 0.2\" mass=\"0.1\" pos=\"-0.5 0 0\" type=\"box\"></geom>\n<body joint=\"frozen\" name=\"imu1\" pos=\"-0.5 0 0.125\">\n<geom color=\"orange\" dim=\"0.2 0.2 0.05\" mass=\"0.05\" type=\"box\"></geom>\n</body>\n</body>\n<body joint=\"rz\" name=\"seg3\" pos=\"1 0 0\">\n<geom dim=\"1 0.25 0.2\" mass=\"0.1\" pos=\"0.5 0 0\" type=\"box\"></geom>\n<body joint=\"frozen\" name=\"imu2\" pos=\"0.5 0 -0.125\">\n<geom color=\"orange\" dim=\"0.2 0.2 0.05\" mass=\"0.05\" type=\"box\"></geom>\n</body>\n</body>\n</body>\n</worldbody>\n</x_xy>\n\"\"\"\n\nsys = ring.System.create(xml_str)\ndata = ring.RCMG(sys, ring.MotionConfig(T=10.0, t_max=1.5), add_X_imus=True,\n add_y_relpose=True, keep_output_extras=True).to_list()\n\n# with `keep_output_extras` really everything one could possibly imagine is returned\n(X, y), (key, qs, xs, sys_mod) = data[0]\n\nframes = sys.render(xs, camera=\"targetfar\")\nshow_video(frames, sys.dt)\n
\neager data generation: 1it [00:05, 5.23s/it]\nRendering frames..: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 1000/1000 [00:05<00:00, 182.81it/s]\n
\n
This browser does not support the video tag. The two orange boxes on segment 1 and segment 3 are modelling our two IMUs. This will be the network's input X
data.
As target we will try to estimate both relative orientations as y
data.
X.keys()\n
\ndict_keys(['seg1', 'seg2', 'seg3'])
\n
X[\"seg1\"].keys()\n
\ndict_keys(['acc', 'gyr'])
\n
y.keys()\n
\ndict_keys(['seg1', 'seg3'])
\n
plt.grid()\nplt.plot(np.arange(0, 10.0, step=sys.dt), X[\"seg1\"][\"gyr\"], label=[\"x\", \"y\", \"z\"])\nplt.ylabel(\"gyro [rad / s]\")\nplt.xlabel(\"time [s]\")\nplt.title(\"IMU 1 Gyroscope\")\nplt.legend()\nplt.show()\n
Now, the IMU is non-zero in all three x/y/z
components.
plt.grid()\nplt.plot(np.arange(0, 10.0, step=sys.dt), y[\"seg1\"], label=[\"w\", \"x\", \"y\", \"z\"])\nplt.xlabel(\"time [s]\")\nplt.title(\"Relative quaternion from seg2 to seg1\")\nplt.legend()\nplt.show()\n
Note how the relative quaternion is only around the y-axis. Can you see why? (Hint: Check the defining xml_str
.)
\n
"},{"location":"notebooks/getting_started/#dynamical-simulation","title":"Dynamical Simulation","text":""},{"location":"notebooks/getting_started/#kinematic-simulation","title":"Kinematic Simulation","text":""},{"location":"notebooks/getting_started/#x-y-training-data-attaching-sensors","title":"X, y
Training data / Attaching sensors","text":""},{"location":"notebooks/imu_modeling/","title":"Imu modeling","text":"from x_xy.subpkgs import exp\nimport matplotlib.pyplot as plt\nimport jax\nimport x_xy\nimport jax.numpy as jnp\n\nhz = 100\nmarkerMap = {\n \"seg1\": 2,\n \"seg5\": 2,\n \"seg2\": 1,\n \"seg3\": 2,\n \"seg4\": 4\n}\n\ndef load_data(seg: str, t1: float, t2: float, motion: str = \"fast\"):\n\n data = exp.load_data(\"S_06\", motion)[seg]\n\n # extract a small window from long time series for plotting\n pos, rot, imu_data = jax.tree_map(lambda arr: arr[int(t1 * hz): int(t2 * hz)], \n (data[f\"marker{markerMap[seg]}\"], data[\"quat\"], data[\"imu_rigid\"]))\n rot = x_xy.maths.quat_inv(rot)\n\n # maximal coordinates of segment, there is (almost) no sensor-to-segment orientation\n xs = x_xy.Transform.create(pos, rot)\n return pos, rot, xs, imu_data\n\n\nt1, t2 = 3.0, 9.0\npos, rot, xs, imu_data = load_data(\"seg1\", t1, t2)\n
Remove gravity from accelerometer to better compare.
def linear_acceleration(xs: x_xy.Transform, acc: jax.Array) -> jax.Array:\n q_E2Imu = xs.rot\n q_Imu2E = x_xy.maths.quat_inv(q_E2Imu)\n gravity = jnp.array([0, 0, 9.81])\n acc_E_nograv = x_xy.maths.rotate(acc, q_Imu2E) - gravity\n return x_xy.maths.rotate(acc_E_nograv, q_E2Imu)\n\nimu_data[\"acc\"] = linear_acceleration(xs, imu_data[\"acc\"])\n
def plot_imu(imu_data: dict):\n imu_data = jax.tree_map(lambda arr: arr[:-100], imu_data.copy())\n fig, axes = plt.subplots(1, 2, figsize=(12, 4))\n ts = jnp.arange(t1, t2 - 1.0, step=1 / hz)\n axes[0].plot(ts, imu_data[\"acc\"], label=[\"x\", \"y\", \"z\"])\n axes[1].plot(ts, imu_data[\"gyr\"], label=[\"x\", \"y\", \"z\"])\n for ax in axes:\n ax.grid()\n ax.set_xlabel(\"time [s]\")\n ax.legend()\n axes[0].set_title(\"Acc\")\n axes[1].set_title(\"Gyr\")\n\nplot_imu(imu_data)\n
imu_data = x_xy.imu(xs, gravity=jnp.zeros((3,)), dt=1 / hz)\nplot_imu(imu_data)\n
Accelerometer doesn't look too great! We need low-pass filtering. Two options:
imu_data = x_xy.imu(xs, gravity=jnp.zeros((3,)), dt=1 / hz, quasi_physical=True)\nplot_imu(imu_data)\n
imu_data = x_xy.imu(xs, gravity=jnp.zeros((3,)), dt=1 / hz, low_pass_filter_pos_f_cutoff=15.0, low_pass_filter_rot_alpha=0.55)\nplot_imu(imu_data)\n
from scipy.optimize import minimize\n\ndef optimize_parameters(seg: str, motion: str):\n\n # include all `fast` data in the optimization\n t1, t2 = 0.0, 500.0\n pos, rot, xs, imu_data = load_data(seg, t1, t2, motion)\n imu_data[\"acc\"] = linear_acceleration(xs, imu_data[\"acc\"])\n\n @jax.jit\n def objective(params):\n f_cutoff, alpha, offset = params\n\n # probably move about 5cm negative x-axis in local CS for e.g. segment 1\n pos_offset = x_xy.maths.rotate(x_xy.maths.rotate(pos, rot) + jnp.array([offset, 0, 0]), x_xy.maths.quat_inv(rot))\n xs_offset = xs.replace(pos=pos_offset)\n imu = x_xy.imu(xs_offset, jnp.zeros((3,)), 1 / hz, low_pass_filter_pos_f_cutoff=f_cutoff, low_pass_filter_rot_alpha=alpha)\n\n return jnp.mean((imu_data[\"acc\"] - imu[\"acc\"])**2) + jnp.mean((imu_data[\"gyr\"] - imu[\"gyr\"])**2)\n\n return minimize(objective, jnp.array([5.0, 1.0, 0.0]), method=\"Nelder-Mead\")\n
for seg in [\"seg1\", \"seg2\", \"seg3\", \"seg4\", \"seg5\"]:\n print(optimize_parameters(seg, \"fast\"))\n
\n message: Optimization terminated successfully.\n success: True\n status: 0\n fun: 0.7932413816452026\n x: [ 1.135e+01 1.034e+01 1.147e-01]\n nit: 147\n nfev: 287\n final_simplex: (array([[ 1.135e+01, 1.034e+01, 1.147e-01],\n [ 1.135e+01, 1.034e+01, 1.147e-01],\n [ 1.135e+01, 1.034e+01, 1.147e-01],\n [ 1.135e+01, 1.034e+01, 1.147e-01]]), array([ 7.932e-01, 7.932e-01, 7.933e-01, 7.933e-01]))\n message: Optimization terminated successfully.\n success: True\n status: 0\n fun: 0.40395233035087585\n x: [ 1.123e+01 1.112e+01 1.159e-01]\n nit: 98\n nfev: 198\n final_simplex: (array([[ 1.123e+01, 1.112e+01, 1.159e-01],\n [ 1.123e+01, 1.112e+01, 1.159e-01],\n [ 1.123e+01, 1.112e+01, 1.159e-01],\n [ 1.123e+01, 1.112e+01, 1.159e-01]]), array([ 4.040e-01, 4.040e-01, 4.040e-01, 4.040e-01]))\n message: Optimization terminated successfully.\n success: True\n status: 0\n fun: 0.37816864252090454\n x: [ 1.190e+01 1.226e+01 1.195e-01]\n nit: 121\n nfev: 238\n final_simplex: (array([[ 1.190e+01, 1.226e+01, 1.195e-01],\n [ 1.190e+01, 1.226e+01, 1.195e-01],\n [ 1.190e+01, 1.226e+01, 1.195e-01],\n [ 1.190e+01, 1.226e+01, 1.195e-01]]), array([ 3.782e-01, 3.782e-01, 3.782e-01, 3.782e-01]))\n message: Optimization terminated successfully.\n success: True\n status: 0\n fun: 0.541861355304718\n x: [ 1.131e+01 1.372e+01 1.160e-01]\n nit: 173\n nfev: 330\n final_simplex: (array([[ 1.131e+01, 1.372e+01, 1.160e-01],\n [ 1.131e+01, 1.372e+01, 1.160e-01],\n [ 1.131e+01, 1.372e+01, 1.160e-01],\n [ 1.131e+01, 1.372e+01, 1.160e-01]]), array([ 5.419e-01, 5.419e-01, 5.419e-01, 5.419e-01]))\n message: Optimization terminated successfully.\n success: True\n status: 0\n fun: 0.6123160123825073\n x: [ 1.106e+01 9.883e+00 1.211e-01]\n nit: 102\n nfev: 202\n final_simplex: (array([[ 1.106e+01, 9.883e+00, 1.211e-01],\n [ 1.106e+01, 9.883e+00, 1.211e-01],\n [ 1.106e+01, 9.883e+00, 1.211e-01],\n [ 1.106e+01, 9.883e+00, 1.211e-01]]), array([ 6.123e-01, 6.123e-01, 6.123e-01, 6.123e-01]))\n
\n
for seg in [\"seg1\", \"seg2\", \"seg3\", \"seg4\", \"seg5\"]:\n print(optimize_parameters(seg, \"slow1\"))\n
\n message: Optimization terminated successfully.\n success: True\n status: 0\n fun: 0.09304095804691315\n x: [ 9.910e+00 3.885e-01 1.136e-01]\n nit: 111\n nfev: 211\n final_simplex: (array([[ 9.910e+00, 3.885e-01, 1.136e-01],\n [ 9.910e+00, 3.885e-01, 1.136e-01],\n [ 9.910e+00, 3.885e-01, 1.136e-01],\n [ 9.910e+00, 3.885e-01, 1.136e-01]]), array([ 9.304e-02, 9.305e-02, 9.305e-02, 9.305e-02]))\n message: Optimization terminated successfully.\n success: True\n status: 0\n fun: 0.02368384227156639\n x: [ 1.008e+01 3.732e-01 1.332e-01]\n nit: 97\n nfev: 190\n final_simplex: (array([[ 1.008e+01, 3.732e-01, 1.332e-01],\n [ 1.008e+01, 3.732e-01, 1.332e-01],\n [ 1.008e+01, 3.732e-01, 1.332e-01],\n [ 1.008e+01, 3.732e-01, 1.332e-01]]), array([ 2.368e-02, 2.369e-02, 2.369e-02, 2.369e-02]))\n message: Optimization terminated successfully.\n success: True\n status: 0\n fun: 0.01580578088760376\n x: [ 8.666e+00 3.510e-01 1.343e-01]\n nit: 111\n nfev: 219\n final_simplex: (array([[ 8.666e+00, 3.510e-01, 1.343e-01],\n [ 8.666e+00, 3.510e-01, 1.343e-01],\n [ 8.666e+00, 3.510e-01, 1.343e-01],\n [ 8.667e+00, 3.510e-01, 1.343e-01]]), array([ 1.581e-02, 1.581e-02, 1.581e-02, 1.581e-02]))\n message: Optimization terminated successfully.\n success: True\n status: 0\n fun: 0.01700039766728878\n x: [ 8.336e+00 3.602e-01 1.210e-01]\n nit: 108\n nfev: 208\n final_simplex: (array([[ 8.336e+00, 3.602e-01, 1.210e-01],\n [ 8.336e+00, 3.601e-01, 1.210e-01],\n [ 8.335e+00, 3.601e-01, 1.210e-01],\n [ 8.335e+00, 3.602e-01, 1.210e-01]]), array([ 1.700e-02, 1.700e-02, 1.700e-02, 1.700e-02]))\n message: Optimization terminated successfully.\n success: True\n status: 0\n fun: 0.10861615836620331\n x: [ 6.784e+00 3.782e-01 5.929e-04]\n nit: 50\n nfev: 107\n final_simplex: (array([[ 6.784e+00, 3.782e-01, 5.929e-04],\n [ 6.784e+00, 3.782e-01, 5.930e-04],\n [ 6.784e+00, 3.782e-01, 5.930e-04],\n [ 6.784e+00, 3.782e-01, 5.930e-04]]), array([ 1.086e-01, 1.086e-01, 1.086e-01, 1.086e-01]))\n
\n
\n
"},{"location":"notebooks/imu_modeling/#on-what-imus-measure","title":"On \"what IMUs measure\"","text":""},{"location":"notebooks/imu_modeling/#real-world-imu","title":"Real-world IMU","text":""},{"location":"notebooks/imu_modeling/#vanilla-simulated-imu","title":"Vanilla simulated IMU","text":""},{"location":"notebooks/imu_modeling/#quasi-physical-simulation-strategy","title":"Quasi-physical simulation strategy","text":""},{"location":"notebooks/imu_modeling/#butterworth-filtering","title":"Butterworth filtering","text":""},{"location":"notebooks/imu_modeling/#optimize-low-pass-filter-parameters","title":"Optimize low-pass-filter parameters","text":""},{"location":"notebooks/knee_joint_translational_dof/","title":"Knee joint translational dof","text":"This example is available as a jupyter notebook here.
And on Google Colab here
Setup the environment if this is executed on Google Colab.
Make sure to change the runtime type to GPU
. To do this go to Runtime
-> Change runtime type
-> GPU
Otherwise, rendering won't work in Google Colab.
import os\n\ntry:\n import google.colab\n IN_COLAB = True\nexcept:\n IN_COLAB = False\n\nif IN_COLAB:\n os.system(\"pip install --quiet 'x_xy[muj] @ git+https://github.com/SimiPixel/x_xy_v2'\")\n os.system(\"pip install --quiet mediapy\")\n os.system(\"pip install --quiet matplotlib\")\n os.system(\"pip install --quiet dm-haiku\")\n
import x_xy\n# automatically detects colab or not\nx_xy.utils.setup_colab_env()\n\nimport jax\nimport jax.numpy as jnp\nimport haiku as hk\nimport mediapy as media\nimport tree_utils\n
MAX_TRANSLATION = 0.1\nROM_MIN_RAD = 0.0\nROM_MAX_RAD = jnp.pi\n\ndef build_mlp_knee(key: jax.random.PRNGKey = jax.random.PRNGKey(1)):\n\n @hk.without_apply_rng\n @hk.transform\n def mlp(x):\n net = hk.nets.MLP([10, 10, 2], activation=jnp.tanh, w_init=hk.initializers.RandomNormal())\n # normalize the x input; [0, 1]\n x = (x + ROM_MIN_RAD) / (ROM_MAX_RAD - ROM_MIN_RAD)\n # center the x input; [-0.5, 0.5]\n x = (x - 0.5)\n return net(x)\n\n example_q = jnp.zeros((1,))\n params = mlp.init(key, example_q)\n\n def forward(params, q: jax.Array):\n return jax.nn.sigmoid(mlp.apply(params, q)) * MAX_TRANSLATION\n\n return params, forward\n\ndef _knee_init_joint_params(key):\n return build_mlp_knee(key)[0]\n\n\ndef transform_fn_knee(q: jax.Array, params: jax.Array) -> x_xy.Transform:\n forward = build_mlp_knee()[1]\n pos = jnp.concatenate((forward(params, q), jnp.array([0.0])))\n axis = jnp.array([0, 0, 1.0])\n rot = x_xy.maths.quat_rot_axis(axis, jnp.squeeze(q))\n return x_xy.Transform(pos, rot)\n\n\ndef draw_fn_knee(config: x_xy.MotionConfig, key_t, key_value, dt, params):\n qs = x_xy.algorithms.jcalc._draw_rxyz(config, key_t, key_value, dt, params)\n # rom constraints\n return (qs / (2 * jnp.pi) + 0.5) * (ROM_MAX_RAD - ROM_MIN_RAD) + ROM_MIN_RAD\n\nx_xy.register_new_joint_type(\"knee\", x_xy.JointModel(transform_fn_knee, rcmg_draw_fn=draw_fn_knee, init_joint_params=_knee_init_joint_params), 1, 0)\n
HIP_REVOLUTE_JOINT = True\n\nxml_str = f\"\"\"\n<x_xy>\n<options dt=\"0.01\" gravity=\"0 0 9.81\"></options>\n<worldbody>\n<geom dim=\"0.15\" type=\"xyz\"></geom>\n<body euler=\"90 90 0\" joint=\"py\" name=\"_femur\" pos=\"0.5 0.5 0.8\">\n<geom dim=\"0.1\" type=\"xyz\"></geom>\n<body \"frozen\"}\"=\"\" else=\"\" hip_revolute_joint=\"\" if=\"\" joint=\"{\" name=\"femur\" rz\"=\"\">\n<geom dim=\"0.1\" type=\"xyz\"></geom>\n<geom dim=\"0.05 0.4\" euler=\"0 90 0\" mass=\"10\" pos=\"0.2 0 0\" type=\"capsule\"></geom>\n<body joint=\"frozen\" name=\"imu1\" pos=\"0.2 0 0.06\">\n<geom dim=\"0.05\" type=\"xyz\"></geom>\n<geom color=\"orange\" dim=\"0.05 0.05 0.02\" mass=\"0\" type=\"box\"></geom>\n</body>\n<body joint=\"knee\" name=\"tibia\" pos=\"0.4 0 0\">\n<geom dim=\"0.1\" type=\"xyz\"></geom>\n<geom dim=\"0.04 0.4\" euler=\"0 90 0\" mass=\"10\" pos=\"0.2 0 0\" type=\"capsule\"></geom>\n<body joint=\"frozen\" name=\"imu2\" pos=\"0.2 0 0.06\">\n<geom dim=\"0.05\" type=\"xyz\"></geom>\n<geom color=\"orange\" dim=\"0.05 0.05 0.02\" mass=\"0\" type=\"box\"></geom>\n</body>\n<geom dim=\"0.025 0.2 0.05\" mass=\"5.0\" pos=\"0.45 -.1 0\" type=\"box\"></geom>\n</body>\n</body>\n</body>\n</worldbody>\n</x_xy>\n\"\"\"\n\nsys = x_xy.load_sys_from_str(xml_str)\n
def finalize_fn(key, qs, xs: x_xy.Transform, sys: x_xy.System):\n X = {}\n for imu in [\"imu1\", \"imu2\"]:\n xs_imu = xs.take(sys.name_to_idx(imu), axis=1)\n X[imu] = {}\n X[imu][\"pos\"] = xs_imu.pos\n X[imu][\"quat\"] = xs_imu.rot\n X[imu][\"imu\"] = x_xy.imu(xs_imu, sys.gravity, sys.dt)\n\n params = tree_utils.tree_slice(sys.links.joint_params[\"knee\"], sys.name_to_idx(\"tibia\"))\n return qs, xs, X, params\n\ndata = x_xy.build_generator(sys, x_xy.MotionConfig(t_min=0.1, t_max=0.75, T=30), finalize_fn=finalize_fn, randomize_joint_params=True, eager=True, aslist=True, seed=1, sizes=32)\n
\neager data generation: 1it [00:07, 7.23s/it]\n
\n
idx = 5\nqs, xs, X, params = data[idx]\n
import matplotlib.pyplot as plt\n\n\nphi = jnp.linspace(0.0, jnp.pi)[:, None]\n# meter -> centimeter\ntrans_x, trans_y = jax.vmap(lambda arr: build_mlp_knee()[1](params, arr))(phi).T * 100\nplt.scatter(trans_x, trans_y, c=phi, cmap=\"coolwarm\")\nplt.colorbar()\nplt.grid()\nplt.xlabel(\"x translation [cm]\")\nplt.ylabel(\"y translation [cm]\")\n
\nText(0, 0.5, 'y translation [cm]')
\n
media.show_video(x_xy.render(sys, [xs[i] for i in range(0, xs.shape(), 4)], camera=\"target\", width=1280, height=720), fps=25)\n
\nRendering frames..: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 750/750 [00:05<00:00, 133.26it/s]\n
\n
This browser does not support the video tag. \n
"},{"location":"notebooks/knee_joint_translational_dof/#registering-a-knee-joint-type","title":"Registering a Knee Joint Type","text":""},{"location":"notebooks/machine_learning/","title":"Machine learning","text":"Note
This example is available as a jupyter notebook here.
And on Google Colab here
Setup the environment if this is executed on Google Colab.
Make sure to change the runtime type to GPU
. To do this go to Runtime
-> Change runtime type
-> GPU
Otherwise, rendering won't work in Google Colab.
import os\n\ntry:\n import google.colab\n IN_COLAB = True\nexcept:\n IN_COLAB = False\n\nif IN_COLAB:\n os.system(\"pip install --quiet 'ring @ git+https://github.com/SimiPixel/ring'\")\n os.system(\"pip install --quiet mediapy\")\n\nimport ring\n# automatically detects colab or not\nring.utils.setup_colab_env()\n\nimport mediapy\nimport jax.numpy as jnp\nimport tree_utils\nfrom ring import exp, sim2real, ml\n
imtp = exp.IMTP([\"seg2\", \"seg3\", \"seg4\"], sparse=True, joint_axes=True)\nexp_id = \"S_04\"\nmotion = \"thomas_fast\"\nringnet = ml.RING_ICML24()\nerrors, X, y, yhat, xs, xs_noimu = exp.benchmark_fn(imtp, exp_id, motion, ringnet, warmup=5.0)\nsys = imtp.sys(exp_id)\nframes = sys.render_prediction(xs, yhat, stepframe=4, transparent_segment_to_root=False, width=640, height=480, camera=\"c\", \n add_cameras={-1: '<camera mode=\"targetbody\" name=\"c\" pos=\".5 -.5 1.25\" target=\"3\"></camera>',})\n
\nDetected the following sampling rates from `X`: 100.0\n
\n
\nRendering frames..: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 1150/1150 [00:09<00:00, 126.31it/s]\n
\n
errors\n
\n{'seg2': {'mae': 4.0391045, 'std': 2.5817297},\n 'seg3': {'mae': 5.3936434, 'std': 2.996799},\n 'seg4': {'mae': 4.5474825, 'std': 2.3526802}}
\n
mediapy.show_video(frames, fps=25.0)\n
This browser does not support the video tag. \n
"},{"location":"notebooks/magnetometer_modeling/","title":"Magnetometer modeling","text":"from x_xy.subpkgs import exp\nimport matplotlib.pyplot as plt\nimport jax\nimport x_xy\nimport jax.numpy as jnp\nimport numpy as np\n\nhz = 100\nmarkerMap = {\n \"seg1\": 2,\n \"seg5\": 2,\n \"seg2\": 1,\n \"seg3\": 2,\n \"seg4\": 4\n}\n\ndef load_data(seg: str, t1: float, t2: float, motion: str = \"fast\"):\n\n data = exp.load_data(\"S_06\", motion, resample_to_hz=hz)[seg]\n\n # extract a small window from long time series for plotting\n pos, rot, imu_data = jax.tree_map(lambda arr: arr[int(t1 * hz): int(t2 * hz)], \n (data[f\"marker{markerMap[seg]}\"], data[\"quat\"], data[\"imu_rigid\"]))\n rot = x_xy.maths.quat_inv(rot)\n\n # maximal coordinates of segment, there is (almost) no sensor-to-segment orientation\n xs = x_xy.Transform.create(pos, rot)\n return pos, rot, xs, imu_data\n\n\nt1, t2 = 3.0, 9.0\npos, rot, xs, imu_data = load_data(\"seg1\", t1, t2)\n
def plot(*mag_data):\n mag_data = jax.tree_map(lambda arr: arr[:-100], mag_data)\n _, axes = plt.subplots(1, len(mag_data), figsize=(len(mag_data)*6, 4))\n axes = [axes] if not isinstance(axes, np.ndarray) else axes\n ts = jnp.arange(t1, t2 - 1.0, step=1 / hz)\n\n for i, mag in enumerate(mag_data):\n axes[i].plot(ts, mag, label=[\"x\", \"y\", \"z\"])\n axes[i].grid()\n axes[i].set_xlabel(\"time [s]\")\n\n axes[0].legend()\n\nplot(imu_data[\"mag\"])\n
imu_data_sim = x_xy.imu(xs, jnp.array([0, 0, 9.81]), 1/hz, jax.random.PRNGKey(1), has_magnetometer=True, low_pass_filter_rot_alpha=0.5)\nplot(imu_data[\"mag\"], imu_data_sim[\"mag\"])\n
from scipy.optimize import minimize\n\ndef optimize_parameters(seg: str, motion: str):\n t1, t2 = 0.0, 500.0\n pos, rot, xs, imu_data = load_data(seg, t1, t2, motion)\n\n @jax.jit\n def objective(params):\n magvec= params\n #alpha = np.clip(alpha, 0.0, 1.0)\n\n imu_sim = x_xy.imu(xs, jnp.zeros((3,)), 1 / hz, \n low_pass_filter_rot_alpha=0.5, magvec=magvec, has_magnetometer=True)\n\n return jnp.mean((imu_data[\"mag\"] - imu_sim[\"mag\"])**2)\n\n res = minimize(objective, jnp.array([0.0, .7, -.7]), method=\"Nelder-Mead\")\n\n perfect = np.array([0, res.x[1], res.x[2]])\n perfect /= np.linalg.norm(perfect)\n dip_angle = np.arctan2(perfect[1], perfect[2])\n return res.x, np.linalg.norm(res.x), np.rad2deg(dip_angle) - 90\n
for seg in [\"seg1\", \"seg2\", \"seg3\", \"seg4\", \"seg5\"]:\n print(optimize_parameters(seg, \"slow1\"))\n
\n(array([-0.05669107, 0.13636727, -0.56856133]), 0.5874282070698279, 76.5125997928424)\n(array([ 0.02870585, 0.14479726, -0.5529681 ]), 0.5723320607239446, 75.32629421386511)\n(array([ 0.07342922, 0.27993262, -0.6070893 ]), 0.6725411131056166, 65.24528171284635)\n(array([ 0.06965261, 0.12674702, -0.66338645]), 0.6789682416674281, 79.18339336223758)\n(array([-0.02896293, 0.24820061, -0.55680701]), 0.6103084782009606, 65.9747430877396)\n
\n
for seg in [\"seg1\", \"seg2\", \"seg3\", \"seg4\", \"seg5\"]:\n print(optimize_parameters(seg, \"fast\"))\n
\n(array([-0.08539633, 0.15602869, -0.49032469]), 0.5215896749593268, 72.34814710540633)\n(array([ 0.05422703, 0.13053918, -0.2375643 ]), 0.27643777497021754, 61.211650723518005)\n(array([ 0.17069941, 0.16292433, -0.49904502]), 0.5520222414783663, 71.9195779531855)\n(array([ 0.03610723, 0.06886188, -0.5142856 ]), 0.5201301476029805, 82.37356382782355)\n(array([-0.13417971, 0.32559843, -0.40739543]), 0.5385067931956347, 51.36746475752713)\n
\n
Test optimized magnetic field vector
pos, rot, xs, imu_data = load_data(\"seg1\", t1, t2, \"fast\")\nimu_data_sim = x_xy.imu(xs, jnp.array([0, 0, 9.81]), 1/hz, has_magnetometer=True, low_pass_filter_rot_alpha=0.56,\n magvec=jnp.array([-0.08957149, 0.17059967, -0.59387128]))\nplot(imu_data[\"mag\"], imu_data_sim[\"mag\"])\n
pos, rot, xs, imu_data = load_data(\"seg1\", t1, t2, \"slow1\")\nimu_data_sim = x_xy.imu(xs, jnp.array([0, 0, 9.81]), 1/hz, has_magnetometer=True, low_pass_filter_rot_alpha=0.5,\n magvec=jnp.array([-0.05896413, 0.14859727, -0.6037423 ]), noisy=True, key=jax.random.PRNGKey(7))\nplot(imu_data[\"mag\"], imu_data_sim[\"mag\"])\n
\n
"},{"location":"notebooks/magnetometer_modeling/#magnetometer-modeling","title":"Magnetometer modeling","text":""},{"location":"notebooks/magnetometer_modeling/#real-world-magnetic-field","title":"Real-world Magnetic-field","text":""},{"location":"notebooks/magnetometer_modeling/#optimize-magnetic-field-vector","title":"Optimize Magnetic Field Vector","text":""},{"location":"notebooks/morph_system/","title":"Morph system","text":"import x_xy\n\nimport jax\nimport jax.numpy as jnp\nimport numpy as np\nimport matplotlib.pyplot as plt\n\nimport mediapy as media\n\ndef show_video(sys, xs: x_xy.Transform) -> None:\n assert sys.dt == 0.01\n # only render every fourth to get a framerate of 25 fps\n frames = x_xy.render(sys, [xs[i] for i in range(0, xs.shape(), 4)], camera=\"targetfar\", height=480, width=640)\n # convert rgba to rgb\n frames = [frame[..., :3] for frame in frames]\n media.show_video(frames, fps=25)\n
In this system the middle segment seg2
acts as \"anchor\".
xml_str = \"\"\"\n<x_xy model=\"three_segment_kinematic_chain\">\n<options dt=\"0.01\" gravity=\"0 0 9.81\"></options>\n<worldbody>\n<body joint=\"free\" name=\"seg2\" pos=\"0 0 1\">\n<geom color=\"self\" dim=\"1 0.25 0.2\" mass=\"0.1\" pos=\"0.5 0 0\" type=\"box\"></geom>\n<body joint=\"ry\" name=\"seg1\">\n<geom color=\"self\" dim=\"1 0.25 0.2\" mass=\"0.1\" pos=\"-0.5 0 0\" type=\"box\"></geom>\n<body joint=\"frozen\" name=\"imu1\" pos=\"-0.5 0 0.125\">\n<geom color=\"orange\" dim=\"0.2 0.2 0.05\" mass=\"0.05\" type=\"box\"></geom>\n</body>\n</body>\n<body joint=\"rz\" name=\"seg3\" pos=\"1 0 0\">\n<geom color=\"self\" dim=\"1 0.25 0.2\" mass=\"0.1\" pos=\"0.5 0 0\" type=\"box\"></geom>\n<body joint=\"frozen\" name=\"imu2\" pos=\"0.5 0 -0.125\">\n<geom color=\"orange\" dim=\"0.2 0.2 0.05\" mass=\"0.05\" type=\"box\"></geom>\n</body>\n</body>\n</body>\n</worldbody>\n</x_xy>\n\"\"\"\n\nsys = x_xy.load_sys_from_str(xml_str)\n\ngen = x_xy.build_generator(sys, x_xy.MotionConfig(T=10.0, t_max=1.5, dang_max_free_spherical=0.1, dpos_max=0.1), _compat=True)\n_, xs = gen(jax.random.PRNGKey(1))\nshow_video(sys, xs)\n
\nRendering frames..: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 250/250 [00:01<00:00, 162.68it/s]\n
\n
This browser does not support the video tag. Can you see what i mean? The middle segment has all the \"global rotation and translation\".
Let's move the anchor to seg1
but without changing the xml syntax. This can be done with the subpackage sys_composer
.
from x_xy.subpkgs import sys_composer\n
# the new parents of seg2, seg1, imu1, seg3, imu2 are ...\nnew_parents = [\"seg1\", -1, \"seg1\", \"seg2\", \"seg3\"]\nsys = sys_composer.morph_system(sys, new_parents=new_parents)\n\ngen = x_xy.build_generator(sys, x_xy.MotionConfig(T=10.0, t_max=1.5, dang_max_free_spherical=0.1, dpos_max=0.1), _compat=True)\n_, xs = gen(jax.random.PRNGKey(1))\nshow_video(sys, xs)\n
\nRendering frames..: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 250/250 [00:01<00:00, 147.05it/s]\n
\n
This browser does not support the video tag. Pretty cool, ha? :)
\n
"},{"location":"notebooks/morph_system/#different-anchors-explains-sys_composermorph_system","title":"Different Anchors (explains sys_composer.morph_system)","text":""},{"location":"notebooks/motion_artifact_rejection/","title":"Motion artifact rejection","text":"This example is available as a jupyter notebook here.
And on Google Colab here
Setup the environment if this is executed on Google Colab.
Make sure to change the runtime type to GPU
. To do this go to Runtime
-> Change runtime type
-> GPU
Otherwise, rendering won't work in Google Colab.
import os\n\ntry:\n import google.colab\n IN_COLAB = True\nexcept:\n IN_COLAB = False\n\nif IN_COLAB:\n os.system(\"pip install --quiet 'x_xy[muj] @ git+https://github.com/SimiPixel/x_xy_v2'\")\n os.system(\"pip install --quiet mediapy\")\n
import x_xy\n# automatically detects colab or not\nx_xy.utils.setup_colab_env()\n\nimport jax\nimport jax.numpy as jnp\n\nimport mediapy as media\n\ndef show_video(sys, xs, **kwargs):\n media.show_video(x_xy.render(sys, [xs[i] for i in range(0, xs.shape(), 4)], camera=\"target\", width=640, height=480, **kwargs), fps=25)\n
knee_xml_str = \"\"\"\n<x_xy model=\"knee_flexible_imus\">\n<options dt=\"0.01\" gravity=\"0 0 9.81\"></options>\n<worldbody>\n<body damping=\"5 5 5 25 25 25\" joint=\"free\" name=\"femur\" pos=\"0.5 0.5 0.3\">\n<geom dim=\"0.1\" type=\"xyz\"></geom>\n<geom dim=\"0.05 0.4\" euler=\"0 90 0\" mass=\"1\" pos=\"0.2 0 0\" type=\"capsule\"></geom>\n<body joint=\"frozen\" name=\"imu1\" pos=\"0.2 0 0.05\" pos_max=\"0.35 0 0\" pos_min=\"0.05 0 0\">\n<geom dim=\"0.05\" type=\"xyz\"></geom>\n<geom color=\"orange\" dim=\"0.05 0.05 0.02\" mass=\"0.1\" type=\"box\"></geom>\n</body>\n<body damping=\"3\" joint=\"ry\" name=\"tibia\" pos=\"0.4 0 0\">\n<geom dim=\"0.1\" type=\"xyz\"></geom>\n<geom dim=\"0.04 0.4\" euler=\"0 90 0\" mass=\"1\" pos=\"0.2 0 0\" type=\"capsule\"></geom>\n<body joint=\"frozen\" name=\"imu2\" pos=\"0.2 0 0.05\" pos_max=\"0.35 0 0\" pos_min=\"0.05 0 0\">\n<geom dim=\"0.05\" type=\"xyz\"></geom>\n<geom color=\"orange\" dim=\"0.05 0.05 0.02\" mass=\"0.1\" type=\"box\"></geom>\n</body>\n<geom dim=\"0.025 0.05 0.2\" mass=\"0\" pos=\"0.45 0 .1\" type=\"box\"></geom>\n</body>\n</body>\n</worldbody>\n</x_xy>\n\"\"\"\n\nsys = x_xy.load_sys_from_str(knee_xml_str)\n
media.show_image(x_xy.render(sys, camera=\"target\", height=480, width=640)[0])\n
\nRendering frames..: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 1/1 [00:00<00:00, 14.47it/s]\n
\n
T = 20.0\nT_initial_nomotion = 2.0\n\nconfig = x_xy.MotionConfig(t_min=0.1, t_max=0.75, T=T, ang0_min=0.0, ang0_max=0.0, pos_min=-1.0, pos_max=1.0, dpos_max=0.5)\nconfig = x_xy.join_motionconfigs([config.to_nomotion_config(), config], [T_initial_nomotion])\n\n(X, y), (_, qs, xs, sys_mod) = x_xy.build_generator(sys, config, imu_motion_artifacts=True, dynamic_simulation=True, eager=True, \n aslist=True, seed=1, sizes=1, keep_output_extras=True, imu_motion_artifacts_kwargs=dict(hide_injected_bodies=False))[0]\n
\n/Users/simon/Documents/PYTHON/x_xy_v2/x_xy/algorithms/generator/motion_artifacts.py:80: UserWarning: `sys.links.joint_params` has been set to zero, this might lead to unexpected behaviour unless you use `randomize_joint_params`\n warnings.warn(\n/Users/simon/Documents/PYTHON/x_xy_v2/x_xy/algorithms/generator/base.py:184: UserWarning: `imu_motion_artifacts` is enabled but not `randomize_motion_artifacts`\n warnings.warn(\neager data generation: 1it [00:28, 28.97s/it]\n
\n
show_video(sys_mod, xs, show_floor=False)\n
\nRendering frames..: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 500/500 [00:01<00:00, 315.52it/s]\n
\n
This browser does not support the video tag. sys_frozen = sys_mod.freeze(\"tibia\").freeze(\"femur\")\n\ndef freeze_x(q_obs):\n q_frozen = jnp.concatenate(tuple(q_obs[:, sys_mod.idx_map(\"q\")[name]] for name in [\"_imu1\", \"imu1\", \"_imu2\", \"imu2\"]), axis=-1)\n return jax.vmap(lambda q: x_xy.algorithms.forward_kinematics_transforms(sys_frozen, q)[0])(q_frozen)\n
show_video(sys_frozen, freeze_x(qs))\n
\nRendering frames..: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 500/500 [00:02<00:00, 172.52it/s]\n
\n
This browser does not support the video tag. \n
"},{"location":"notebooks/visualisation/","title":"Visualisation","text":"import ring\nfrom ring import exp\nimport mediapy as media\nimport jax\n
sys_str = \"\"\"\n<x_xy>\n<worldbody>\n<geom dim=\"0.1\" type=\"xyz\"></geom>\n<body joint=\"free\" name=\"seg\" pos=\"0 0 .5\">\n<geom color=\"dustin_exp_blue\" dim=\"0.15 0.075 0.05\" mass=\"0.2\" pos=\"0.03 0 0\" type=\"box\"></geom>\n<body joint=\"frozen\" name=\"imu\" pos=\"0.0 0.0 0.03\">\n<geom color=\"dustin_exp_orange\" dim=\"0.05 0.03 0.02\" mass=\"0.1\" type=\"box\"></geom>\n</body>\n</body>\n</worldbody>\n</x_xy>\n\"\"\"\n
sys = ring.System.create(sys_str)\n
(X, y), (key, q, x, _) = ring.RCMG(sys, keep_output_extras=True).to_list()[0]\n
\neager data generation: 1it [00:01, 1.58s/it]\n
\n
media.show_video(sys.render(x, width=640, height=480, camera=\"target\", render_every_nth=4), fps=25)\n
\nRendering frames..: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 1500/1500 [00:08<00:00, 167.58it/s]\n
\n
This browser does not support the video tag. exp_data = exp.load_data(exp_id=\"S_06\", motion_start=\"fast\")\n
exp_data.keys()\n
\ndict_keys(['seg1', 'seg2', 'seg3', 'seg4', 'seg5'])
\n
exp_data[\"seg1\"].keys()\n
\ndict_keys(['imu_flex', 'imu_rigid', 'marker1', 'marker2', 'marker3', 'marker4', 'quat'])
\n
segment = \"seg2\"\nomc_data_sys = {\n \"seg\": {\n \"pos\": exp_data[segment][\"marker1\"],\n \"quat\": exp_data[segment][\"quat\"],\n },\n \"imu\": {\n \"quat\": exp_data[segment][\"quat\"],\n }\n}\n
omc_data_sys\n
\n{'seg': {'pos': Array([[ 0.26832035, 1.1925832 , -0.06244465],\n [ 0.268319 , 1.1925788 , -0.06244366],\n [ 0.26831627, 1.1925731 , -0.06244285],\n ...,\n [ 0.19899347, 1.2143577 , -0.06264979],\n [ 0.19898805, 1.2143621 , -0.06264362],\n [ 0.19897974, 1.214368 , -0.06263389]], dtype=float32),\n 'quat': Array([[ 0.95449424, 0.10144146, -0.01664882, -0.27995223],\n [ 0.95449567, 0.10146226, -0.01665274, -0.27993944],\n [ 0.95449716, 0.10148306, -0.01665665, -0.27992663],\n ...,\n [ 0.9600311 , -0.01755334, 0.02215116, -0.2784628 ],\n [ 0.9600333 , -0.01756094, 0.02225687, -0.27844635],\n [ 0.96003544, -0.01756854, 0.02236257, -0.27842987]], dtype=float32)},\n 'imu': {'quat': Array([[ 0.95449424, 0.10144146, -0.01664882, -0.27995223],\n [ 0.95449567, 0.10146226, -0.01665274, -0.27993944],\n [ 0.95449716, 0.10148306, -0.01665665, -0.27992663],\n ...,\n [ 0.9600311 , -0.01755334, 0.02215116, -0.2784628 ],\n [ 0.9600333 , -0.01756094, 0.02225687, -0.27844635],\n [ 0.96003544, -0.01756854, 0.02236257, -0.27842987]], dtype=float32)}}
\n
x = ring.sim2real.xs_from_raw(sys, omc_data_sys)\n\n# vectorize this function over time\n@jax.vmap\ndef update_position_vector_of_imu(x):\n state = ring.State.create(sys, x=x)\n # populate minimal coordinates `state.q` from maximal coordinates `state.x`\n state = ring.algorithms.inverse_kinematics(sys, state)\n # re-calculate maximal coordiantes `state.x` from minimal coordinates `state.q`\n # this uses the position vector specified in the system (and so the xml file)\n # to produce an offset between IMu and segment geom box\n _, state = ring.algorithms.forward_kinematics(sys, state)\n return state.x\n\nx = update_position_vector_of_imu(x)\n
media.show_video(sys.render(x, width=640, height=480, camera=\"target\", render_every_nth=4), fps=25)\n
\nRendering frames..: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 1075/1075 [00:07<00:00, 134.95it/s]\n
\n
This browser does not support the video tag. \n
"},{"location":"prism/ss_23_marcel_thomas/notebook/","title":"Notebook","text":"import x_xy\nimport jax\nimport jax.numpy as jnp\nimport jax.random as random\nfrom x_xy.subpkgs.ml import rnno, callbacks, train, load\nfrom x_xy.subpkgs import sim2real, sys_composer\nimport tree_utils\nimport matplotlib.pyplot as plt\nimport mediapy as media\n
three_seg_rigid = r\"\"\"\n<x_xy model=\"three_seg_rigid\">\n<options dt=\"0.01\" gravity=\"0 0 9.81\"></options>\n<worldbody>\n<body joint=\"free\" name=\"seg2\">\n<geom color=\"red\" dim=\"0.2 0.05 0.05\" mass=\"10\" pos=\"0.1 0 0\" type=\"box\"></geom>\n<body joint=\"rsry\" name=\"seg1\" pos=\"0 0 0\">\n<geom color=\"yellow\" dim=\"0.2 0.05 0.05\" mass=\"10\" pos=\"-0.1 0 0\" type=\"box\"></geom>\n<body joint=\"frozen\" name=\"imu1\" pos=\"-0.1 0.0 0.03\">\n<geom color=\"green\" dim=\"0.05 0.01 0.01\" mass=\"2\" pos=\"0 0 0\" type=\"box\"></geom>\n</body>\n</body>\n<body joint=\"rsrz\" name=\"seg3\" pos=\"0.2 0 0\">\n<geom color=\"blue\" dim=\"0.2 0.05 0.05\" mass=\"10\" pos=\"0.1 0 0\" type=\"box\"></geom>\n<body joint=\"frozen\" name=\"imu2\" pos=\"0.1 0.0 0.03\">\n<geom color=\"green\" dim=\"0.05 0.01 0.01\" mass=\"2\" pos=\"0 0 0\" type=\"box\"></geom>\n</body>\n</body>\n</body>\n</worldbody>\n<defaults>\n<geom color=\"1 0.8 0.7 1\" edge_color=\"black\"></geom>\n</defaults>\n</x_xy>\n\"\"\"\n
dustin_exp_xml_seg1 = r\"\"\"\n<x_xy model=\"dustin_exp\">\n<options dt=\"0.01\" gravity=\"0 0 9.81\"></options>\n<worldbody>\n<body joint=\"free\" name=\"seg1\">\n<geom dim=\"0.2 0.05 0.05\" mass=\"10\" pos=\"-0.1 0 0\" type=\"box\"></geom>\n<body joint=\"ry\" name=\"seg2\">\n<geom dim=\"0.2 0.05 0.05\" mass=\"10\" pos=\"0.1 0 0\" type=\"box\"></geom>\n<body joint=\"rz\" name=\"seg3\" pos=\"0.2 0 0\">\n<geom dim=\"0.2 0.05 0.05\" mass=\"10\" pos=\"0.1 0 0\" type=\"box\"></geom>\n</body>\n</body>\n</body>\n</worldbody>\n</x_xy>\n\"\"\"\n
# Helper function - Creates an array of values x <- [0, 1] which may be multiplied to another sequence.\ndef motion_amplifier(\n time : float,\n sampling_rate : float,\n key_rigid_phases : jax.Array,\n n_rigid_phases=3,\n rigid_duration_cov=jnp.array([0.02] * 3),\n transition_cov=jnp.array([0.1] * 3)\n) -> jax.Array:\n error_msg = \"motion_amplifier: There must be a variance for each rigid phase!\"\n assert rigid_duration_cov.shape == (n_rigid_phases,) == transition_cov.shape, error_msg\n n_frames = int(time / sampling_rate)\n key_rigid_means, key_rigid_variances, key_slope_down_variances, key_slope_up_variances = random.split(key_rigid_phases, 4)\n\n # Calculate center points of rigid phases\n means = jnp.sort(random.uniform(key_rigid_means, shape=(n_rigid_phases, 1), minval=0, maxval=n_frames).T)\n\n # Calculate durations, which is twice the rigid distance from the center points for each rigid phase.\n rigid_distances = jnp.abs(random.multivariate_normal(\n key_rigid_variances, mean=jnp.zeros_like(means), cov=jnp.diag((rigid_duration_cov * n_frames)**2)))\n\n # Calculate transition durations\n transition_slowdown_durations = jnp.abs(random.multivariate_normal(\n key_slope_down_variances, mean=jnp.zeros_like(means), cov=jnp.diag((transition_cov * n_frames)**2)\n ))\n transition_speedup_durations = jnp.abs(random.multivariate_normal(\n key_slope_up_variances, mean=jnp.zeros_like(means), cov=jnp.diag((transition_cov * n_frames)**2)\n ))\n\n # Phase start and end points\n rigid_starts = (means - rigid_distances).astype(int).flatten()\n rigid_ends = (means + rigid_distances).astype(int).flatten()\n starts_slowing = (means - rigid_distances - transition_slowdown_durations).astype(int).flatten()\n ends_moving = (means + rigid_distances + transition_speedup_durations).astype(int).flatten()\n\n # Create masks\n def create_mask(start, end):\n nonlocal n_frames\n return jnp.where(jnp.arange(n_frames) < start, 1, 0) + jnp.where(jnp.arange(n_frames) >= end, 1, 0)\n\n mask = jax.vmap(create_mask)\n rigid_mask = jnp.prod(mask(rigid_starts, rigid_ends), axis=0)\n slowdown_masks = mask(starts_slowing, rigid_starts).astype(float)\n speedup_masks = mask(rigid_ends, ends_moving).astype(float)\n\n # We have to define an inline function in order to make this code JIT-able\n def linsp(mask, start, end, begin_val, carry_fun):\n range = end - start\n def true_fun(carry, x): return (carry_fun(carry, range), 1 - carry)\n def false_fun(carry, x): return (carry, x)\n def f(carry, x): return jax.lax.cond(\n x == 0, true_fun, false_fun, *(carry, x))\n return jax.lax.scan(f, begin_val, mask)[1]\n\n linsp_desc = jax.vmap(lambda m, s1, s2: linsp( m, s1, s2, 0.0, lambda carry, range: carry + 1/range))\n slowdown_mask = jnp.prod(linsp_desc(slowdown_masks, starts_slowing, rigid_starts), axis=0)\n\n linsp_asc = jax.vmap(lambda m, s1, s2: linsp(m, s1, s2, 1.0, lambda carry, range: carry - 1/range))\n speedup_mask = jnp.prod(linsp_asc(speedup_masks, rigid_ends, ends_moving), axis=0)\n\n return jnp.min(jnp.stack([rigid_mask, slowdown_mask, speedup_mask]), axis=0)\n
# Random generator: Uses the motion_amplifier to dampen/null the randomly generated angles.\ndef random_angles_with_rigid_phases_over_time(\n key_t,\n key_ang,\n dt,\n key_rigid_phases,\n n_rigid_phases=3,\n rigid_duration_cov=jnp.array([0.02] * 3),\n transition_cov=jnp.array([0.1] * 3),\n config: x_xy.algorithms.MotionConfig=x_xy.algorithms.MotionConfig()\n) -> jax.Array:\n\n mask = motion_amplifier(\n config.T,\n dt,\n key_rigid_phases,\n n_rigid_phases,\n rigid_duration_cov,\n transition_cov)\n\n qs = x_xy.algorithms.random_angle_over_time(\n key_t=key_t,\n key_ang=key_ang,\n ANG_0=config.ang0_max,\n dang_min=config.dang_min,\n dang_max=config.dang_max,\n delta_ang_min=config.delta_ang_min,\n delta_ang_max=config.delta_ang_max,\n t_min=config.t_min,\n t_max=config.t_max,\n T=config.T,\n Ts=dt,\n randomized_interpolation=config.randomized_interpolation_angle,\n range_of_motion=config.range_of_motion_hinge,\n range_of_motion_method=config.range_of_motion_hinge_method\n )\n\n # derivate qs\n qs_diff = jnp.diff(qs, axis=0)\n\n # mulitply with motion amplifier\n qs_diff = qs_diff * mask[:-1]\n\n # integrate qs_diff\n qs_rigid_phases = jnp.concatenate((qs[0:1], jnp.cumsum(qs_diff, axis=0)))\n return qs_rigid_phases\n
BEST_RUN = (1, jnp.array([0.02]), jnp.array([0.1]))\nMANY_TINY_STOPS = (30, jnp.array([0.001] * 30), jnp.array([0.0001] * 30))\n##################################################################################\n# Define your own problem configuration here :) #\n\nPROBLEM = BEST_RUN # <- Change this assignment to use it.\n##################################################################################\n\ndef define_joints():\n def _draw_sometimes_rigid(\n config: x_xy.algorithms.MotionConfig, key_t: jax.Array, key_value: jax.Array, dt : float, joint_params : jax.Array\n ) -> jax.Array:\n key_t, key_rigid_phases = jax.random.split(key_t)\n return random_angles_with_rigid_phases_over_time(\n key_t=key_t,\n key_ang=key_value,\n dt=dt,\n key_rigid_phases=key_rigid_phases,\n n_rigid_phases=PROBLEM[0],\n rigid_duration_cov=PROBLEM[1],\n transition_cov=PROBLEM[2],\n config=config\n )\n\n def _rxyz_transform(q, _, axis):\n q = jnp.squeeze(q)\n rot = x_xy.maths.quat_rot_axis(axis, q)\n return x_xy.base.Transform.create(rot=rot)\n\n rsrx_joint = x_xy.algorithms.JointModel(\n lambda q, _: _rxyz_transform(q, _, jnp.array([1.0, 0, 0])), [None], rcmg_draw_fn=_draw_sometimes_rigid\n )\n rsry_joint = x_xy.algorithms.JointModel(\n lambda q, _: _rxyz_transform(q, _, jnp.array([0, 1.0, 0])), [None], rcmg_draw_fn=_draw_sometimes_rigid\n )\n rsrz_joint = x_xy.algorithms.JointModel(\n lambda q, _: _rxyz_transform(q, _, jnp.array([0, 0, 1.0])), [None], rcmg_draw_fn=_draw_sometimes_rigid\n )\n try:\n x_xy.algorithms.register_new_joint_type(\"rsrx\", rsrx_joint, 1)\n x_xy.algorithms.register_new_joint_type(\"rsry\", rsry_joint, 1)\n x_xy.algorithms.register_new_joint_type(\"rsrz\", rsrz_joint, 1)\n except AssertionError:\n print(\"Warning: Joints have already been registered!\")\n\ndefine_joints()\n
Note: it is also possible to support multiple problems at the same time, by implementing them as seperate joint types, or by injecting the x_xy.algorithms.MotionConfig
class e.g. by inheritance.
After we defined the joint type, we can load the system:
sys_rigid = x_xy.io.load_sys_from_str(three_seg_rigid)\nsys_inference = x_xy.io.load_sys_from_str(dustin_exp_xml_seg1)\n
def finalize_fn_imu_data(key, q, x, sys):\n imu_seg_attachment = {\"imu1\": \"seg1\", \"imu2\": \"seg3\"}\n\n X = {}\n for imu, seg in imu_seg_attachment.items():\n key, consume = jax.random.split(key)\n X[seg] = x_xy.algorithms.imu(\n x.take(sys.name_to_idx(imu), 1), sys.gravity, sys.dt, consume, True\n )\n return X\n\n\ndef finalize_fn_rel_pose_data(key, _, x, sys):\n y = x_xy.algorithms.rel_pose(sys_scan=sys_inference, xs=x, sys_xs=sys)\n return y\n\ndef finalize_fn(key, q, x, sys):\n X = finalize_fn_imu_data(key, q, x, sys)\n # Since no IMU is attached to seg2, we need to provide dummy data.\n X[\"seg2\"] = tree_utils.tree_zeros_like(X[\"seg1\"])\n y = finalize_fn_rel_pose_data(key, q, x, sys)\n return X, y\n
The generated data comes is returned in the tuple \\((\\mathbf{X}, \\mathbf{y})\\), with \\(\\mathbf{X}\\) being the generated IMU accelorometer and gyroscope data and \\(\\mathbf{y}\\) the orientation of each segment, in form of a unit quaternion.
def setup_fn_seg2(key, sys: x_xy.base.System) -> x_xy.base.System:\n def replace_pos(transforms, new_pos, name: str):\n i = sys.name_to_idx(name)\n return transforms.index_set(i, transforms[i].replace(pos=new_pos))\n\n def draw_pos_uniform(key, pos_min, pos_max):\n key, c1, c2, c3 = jax.random.split(key, num=4)\n pos = jnp.array(\n [\n jax.random.uniform(c1, minval=pos_min[0], maxval=pos_max[0]),\n jax.random.uniform(c2, minval=pos_min[1], maxval=pos_max[1]),\n jax.random.uniform(c3, minval=pos_min[2], maxval=pos_max[2]),\n ]\n )\n return key, pos\n\n ts = sys.links.transform1\n\n # seg1 relative to seg2\n key, pos = draw_pos_uniform(key, [-0.3, -0.02, -0.02], [-0.05, 0.02, 0.02])\n ts = replace_pos(ts, pos, \"seg1\")\n\n # imu1 relative to seg1\n key, pos = draw_pos_uniform(\n key, [-0.25, -0.05, -0.05], [-0.05, 0.05, 0.05])\n ts = replace_pos(ts, pos, \"imu1\")\n\n # seg3 relative to seg2\n key, pos = draw_pos_uniform(key, [0.05, -0.02, -0.02], [0.3, 0.02, 0.02])\n ts = replace_pos(ts, pos, \"seg3\")\n\n # imu2 relative to seg2\n key, pos = draw_pos_uniform(key, [0.05, -0.05, -0.05], [0.25, 0.05, 0.05])\n ts = replace_pos(ts, pos, \"imu2\")\n\n return sys.replace(links=sys.links.replace(transform1=ts))\n
With this, we can now train the model: We first define the batch size and number of epochs. For good results, a relatively large number of epochs is required, as the mean average angle error in training converges relatively late in training. Then we plug together the setup- and finalize functions in a generator function, which will provide the batched training data. A logger might also be added, such as a neptune logger. When using neptune, the environment-variables NEPTUNE_TOKEN
and NEPTUNE_PROJECT
must be set accordingly.
TRAINING_BATCH_SIZE = 80\nEPOCHS = 1500\nparams_path = \"parameters.pickle\"\nKEY_GEN = random.PRNGKey(1)\nKEY_NETWORK = random.PRNGKey(1)\n\ngen = x_xy.algorithms.build_generator(sys_rigid, x_xy.algorithms.MotionConfig(), setup_fn_seg2, finalize_fn)\ngen = x_xy.algorithms.batch_generators_lazy(gen, TRAINING_BATCH_SIZE)\n\n# Set 'upload' to True if a logger is attached.\nsave_params = callbacks.SaveParamsTrainingLoopCallback(params_path, upload=False) \n\nloggers = []\n# loggers.append(NeptuneLogger()) # You may add loggers here, e.g. a Neptune Logger\n\nnetwork = rnno.make_rnno(sys_inference)\n
WARNING! Executing this code can take a long time (due to the very high number of epochs) and will probably take up a huge portion of your memory. If you run this code on a GPU, a batch size of 80 takes more than 50 GB of VRAM, so if the execution fails, it might be because of missing GPU memory. To circumvent this, the batch size can be decreased, however, the results will suffer from that.
train(gen, EPOCHS, network, loggers=loggers, callbacks=[save_params], key_generator=KEY_GEN, key_network=KEY_NETWORK)\n
def finalize_fn_inference(key, q, x, sys):\n X = finalize_fn_imu_data(key, q, x, sys)\n y = finalize_fn_rel_pose_data(key, q, x, sys)\n return X, y, x\n\n\ndef generate_inference_data(sys, config: x_xy.algorithms.MotionConfig, seed=random.PRNGKey(1,)):\n generator = x_xy.algorithms.build_generator(sys, config, finalize_fn=finalize_fn_inference)\n X, y, xs = generator(seed)\n return X, y, xs\n
To control the data generated, the MotionConfig data is used. It contains all necessary information about the to-be-generated data series, e.g. time (config.T
), except for the sampling rate, which is stored in the system object (<sys>.dt
) and set in the XML-definition. The finalize function and its return values are similiar to the training finilaize function, however, an addtitional \\(\\mathbf{xs}\\) is returned, containing the actual position and rotation. This can be used for rendering purposes later. Also, the data is not batched, as we currently are only interested in one time series.
config = x_xy.algorithms.MotionConfig()\nprint(f\"Generating data for a time series of {config.T} seconds, with a sampling rate of {1/sys_inference.dt} Hz.\")\n\n# If you are unhappy with your data series, you can alter this seed:\nseed = random.PRNGKey(1337,)\n\nX, y, xs = generate_inference_data(sys_rigid, config, seed)\n\n# Add dummy IMU data for segment 2 (which has no IMU attached)\nX[\"seg2\"] = tree_utils.tree_zeros_like(X[\"seg1\"])\n
\nGenerating data for a time series of 60.0 seconds, with a sampling rate of 100.0 Hz.\n
\n
params = load(\"parameters.pickle\")\n
Finally, we have everything we need to do inference! Let's see how our network performs...
# Run prediction:\nX_3d = tree_utils.to_3d_if_2d(X, strict=True)\ninitial_params, state = network.init(random.PRNGKey(1,), X_3d)\nyhat, _ = network.apply(params, tree_utils.add_batch_dim(state), X_3d)\nyhat = tree_utils.to_2d_if_3d(yhat, strict=True)\n\n# Plot prediction:\ndef plot_segment(segment : str, axis : str, ax):\n axis_idx = \"xyz\".index(axis)\n euler_angles_hat_seg2 = jnp.rad2deg(x_xy.maths.quat_to_euler(yhat[segment])[:,axis_idx])\n euler_angles_seg2 = jnp.rad2deg(x_xy.maths.quat_to_euler(y[segment])[:,axis_idx])\n ax.plot(euler_angles_hat_seg2, label=\"prediction\")\n ax.set_ylim((-180, 180))\n ax.set_title(f\"{segment} ({axis}-axis)\")\n ax.plot(euler_angles_seg2, label=\"truth\")\n ax.set_xlabel(\"time [s]\")\n ax.set_ylabel(\"euler angles [deg]\")\n ax.legend()\n print(f\"{segment}: medium absolute error {jnp.average(jnp.abs(euler_angles_hat_seg2 - euler_angles_seg2))} deg\")\n\nfig, axs = plt.subplots(ncols=2, figsize=(10, 4))\nplot_segment(\"seg2\", 'y', axs[0])\nplot_segment(\"seg3\", 'z', axs[1])\nplt.show()\n
\nseg2: medium absolute error 0.524849534034729 deg\nseg3: medium absolute error 0.5137953162193298 deg\n
\n
Let's also render a video of the prediction and the truth:
# Extract translations from data-generating system...\ntranslations, rotations = sim2real.unzip_xs(sys_inference, sim2real.match_xs(sys_inference, xs, sys_rigid))\nyhat_inv = jax.tree_map(lambda quat: x_xy.maths.quat_inv(quat), yhat) \n\n# ... swap rotations with predicted ones...\nrotations_hat = [] \nfor i, name in enumerate(sys_inference.link_names):\n if name in yhat_inv:\n rotations_name = x_xy.Transform.create(rot=yhat_inv[name])\n else:\n rotations_name = rotations.take(i, axis=1)\n rotations_hat.append(rotations_name)\n\n# ... and plug the positions and rotations back together.\nrotations_hat = rotations_hat[0].batch(*rotations_hat[1:]).transpose((1, 0, 2))\nxs_hat = sim2real.zip_xs(sys_inference, translations, rotations_hat)\n\n# Create combined system that shall be rendered and its transforms\nsys_render = sys_composer.inject_system(sys_rigid, sys_inference.add_prefix_suffix(suffix=\"_hat\"))\nxs_render = x_xy.Transform.concatenate(xs, xs_hat, axis=1)\n\n# Render prediction and truth:\nframes = x_xy.render(sys_render, [xs_render[i] for i in range(xs_render.shape(axis=0))], camera='target')\nmedia.show_video([frame[..., :3] for frame in frames], fps=25)\n
\nRendering frames..: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 6000/6000 [00:16<00:00, 374.76it/s]\n
\n
This browser does not support the video tag."},{"location":"prism/ss_23_marcel_thomas/notebook/#training-the-rnno-with-rigid-phases-prism-ss2023","title":"Training the RNNO with rigid phases (PRISM SS2023)","text":"In this notebook, we define a custom hinge joint, which is configured to generate pauses (no movement) inside the generated data series. We use this joint to train the RNNO and perfrom inference with the generated parameters.
"},{"location":"prism/ss_23_marcel_thomas/notebook/#defining-the-system","title":"Defining the System","text":"A system is defined in an XML structure. To read a system, an XML file may be used. It is also possible to define the system inline by using a string in XML-syntax.In the following, we define two three-segment chains:
"},{"location":"prism/ss_23_marcel_thomas/notebook/#registering-the-joint-axis","title":"Registering the joint axis","text":"For this scenario, we define two systems: One for generating data with rigid phases and one for inference. To generate the random data with rigid phases, we first have to register a joint type, that allows for the creation of such data. We call this joint 'rsr\\<x|y|z>', a hinge joint that produces random sometimes rigid data, and turns around the respective axis \\(x\\), \\(y\\) or \\(z\\) in its frame. </x|y|z>
"},{"location":"prism/ss_23_marcel_thomas/notebook/#generating-random-data","title":"Generating random data","text":"The random data is generated by the following functions:
"},{"location":"prism/ss_23_marcel_thomas/notebook/#defining-the-random-joint-function","title":"Defining the random joint function","text":"First of all, we have to define our problem. This means, parameterzing the random function. Two possible scenarios are implemented below: \"BEST_RUN\" and \"MANY_TINY_STOPS\", both of which achieved adequate results. The problems are defined as \\(P=(N, \\mathbf{\\sigma}_{r}, \\mathbf{\\sigma}_{tr})\\), with \\(N\\) being the number of rigid phases, \\(\\mathbf{\\sigma}_r\\) the covariance used for calculating the length of each rigid phase and \\(\\mathbf{\\sigma}_{tr}\\) for the length of each transition phase respectively. It also holds that \\(\\mathbf{\\sigma}_r, \\mathbf{\\sigma}_{tr} \\in \\mathbb{R}^N\\), with each entry being the variance for exactly one rigid phase.
"},{"location":"prism/ss_23_marcel_thomas/notebook/#generating-raw-data","title":"Generating raw data","text":"For both training and inference, we first need a set of raw data. In our example, sys_rigid
is used to generate the problem-specific data for each IMU. This data will be used for training and later by sys_inference
to estimate the position and orientation of seg2
, which has no IMU attached.
Before we begin with the actual training, we first define a setup function. This is called before training on each time series. The function below alters the length of segments and the position of the IMUs of the system, to simulate inaccuracies, e.g. when dealing with experimental data.
"},{"location":"prism/ss_23_marcel_thomas/notebook/#infering-data","title":"Infering data","text":""},{"location":"prism/ss_23_marcel_thomas/notebook/#inference","title":"Inference","text":"To do inference, we first need to load the parameters (weights) of our model.
"},{"location":"prism/ss_23_moritz/notebook/","title":"Notebook","text":"import jax\nimport jax.numpy as jnp\nimport tree_utils\nfrom jax.nn import softmax\nimport matplotlib.pyplot as plt\nimport mediapy\n\nimport x_xy\nfrom x_xy.subpkgs import ml, sim2real, sys_composer\n
Set the batch size and number of training episodes according to the available hardware.
BATCHSIZE = 32\nNUM_TRAINING_EPISODES = 1500\n
sys_str = r\"\"\"\n<x_xy model=\"three_segment_kinematic_chain\">\n<options dt=\"0.01\" gravity=\"0 0 9.81\"></options>\n<defaults>\n<geom color=\"orange\"></geom>\n</defaults>\n<worldbody>\n<body joint=\"free\" name=\"seg2\" pos=\"0 0 2\">\n<geom dim=\"1 0.25 0.2\" mass=\"0.1\" pos=\"0.5 0 0\" type=\"box\"></geom>\n<body joint=\"ry\" name=\"seg1\">\n<geom dim=\"1 0.25 0.2\" mass=\"0.1\" pos=\"-0.5 0 0\" type=\"box\"></geom>\n<body joint=\"frozen\" name=\"imu1\" pos=\"-0.5 0 0.125\">\n<geom color=\"red\" dim=\"0.2 0.2 0.05\" mass=\"0.05\" type=\"box\"></geom>\n</body>\n</body>\n<body joint=\"rz\" name=\"seg3\" pos=\"1 0 0\">\n<geom dim=\"1 0.25 0.2\" mass=\"0.1\" pos=\"0.5 0 0\" type=\"box\"></geom>\n<body joint=\"frozen\" name=\"imu2\" pos=\"0.5 0 -0.125\">\n<geom color=\"red\" dim=\"0.2 0.2 0.05\" mass=\"0.05\" type=\"box\"></geom>\n</body>\n</body>\n</body>\n</worldbody>\n</x_xy>\n\"\"\"\nsys = x_xy.io.load_sys_from_str(sys_str)\n
dustin_exp_xml_seg1 = r\"\"\"\n<x_xy model=\"dustin_exp\">\n<options dt=\"0.01\" gravity=\"0 0 9.81\"></options>\n<defaults>\n<geom color=\"white\"></geom>\n</defaults>\n<worldbody>\n<body joint=\"free\" name=\"seg1\">\n<geom dim=\"1 0.25 0.2\" mass=\"10\" pos=\"-0.5 0 0\" type=\"box\"></geom>\n<body joint=\"ry\" name=\"seg2\">\n<geom dim=\"1 0.25 0.2\" mass=\"10\" pos=\"0.5 0 0\" type=\"box\"></geom>\n<body joint=\"rz\" name=\"seg3\" pos=\"0.2 0 0\">\n<geom dim=\"1 0.25 0.2\" mass=\"10\" pos=\"0.5 0 0\" type=\"box\"></geom>\n</body>\n</body>\n</body>\n</worldbody>\n</x_xy>\n\"\"\"\nsys_inference = x_xy.io.load_sys_from_str(dustin_exp_xml_seg1)\n
def finalise_fn(key: jax.Array, q: jax.Array, xs: x_xy.Transform, sys: x_xy.System):\n def xs_by_name(name: str):\n return xs.take(sys.name_to_idx(name), axis=1)\n\n key, *consume = jax.random.split(key, 3)\n\n # the input X to our RNNo is the IMU data of segments 1 and 3\n X = {\n \"seg1\": x_xy.imu(xs_by_name(\"imu1\"), sys.gravity, sys.dt, consume[0], True),\n \"seg3\": x_xy.imu(xs_by_name(\"imu2\"), sys.gravity, sys.dt, consume[1], True),\n }\n\n # seg2 has no IMU, but we still need to make an entry in our X\n X[\"seg2\"] = tree_utils.tree_zeros_like(X[\"seg1\"])\n\n # the output of the RNNo is the estimated relative poses of our segments\n y = x_xy.algorithms.rel_pose(sys_scan=sys_inference, xs=xs, sys_xs=sys)\n\n return X, y\n\nconfig = x_xy.algorithms.MotionConfig(dpos_max=0.3, ang0_min=0.0, ang0_max=0.0)\n\ngen = x_xy.build_generator(sys, config, finalize_fn=finalise_fn)\ngen = x_xy.batch_generator(gen, BATCHSIZE)\n
def make_loss_fn(beta):\n def metric_fn(q, q_hat):\n return x_xy.maths.angle_error(q, q_hat) ** 2\n\n if beta is not None:\n\n def loss_fn(q, q_hat):\n # q.shape == q_hat.shape == (1000, 4)\n angles = metric_fn(q, q_hat)\n\n factors = angles.shape[-1] * softmax(\n beta * jax.lax.stop_gradient(angles), axis=-1\n )\n\n errors = factors * angles\n\n return errors\n\n else:\n loss_fn = metric_fn\n\n return loss_fn\n
beta
determines the strength of our weighting: the larger beta, the more relative weight we put on the larger errors, while beta = 0.0
makes the scaling factors uniform one and gives us back our unweighted errors. Alternatively beta = None
bypasses the scaling altogether.
beta = 1.0\n
rnno = ml.make_rnno(sys_inference)\n\nloss_fn = make_loss_fn(beta)\n\nsave_params = ml.callbacks.SaveParamsTrainingLoopCallback(\n \"parameters.pickle\", upload=False\n)\n\nml.train(gen, NUM_TRAINING_EPISODES, rnno, callbacks=[save_params], loss_fn=loss_fn)\n
To visualise our network, we can render it using mediapy. First we generate some motion data.
gen = x_xy.build_generator(sys, config)\n\nkey = jax.random.PRNGKey(1)\n\nq, xs = gen(key)\n
We need to again bring the motion data in the correct form for our RNNo and can then run inference of the generated data.
params = ml.load(\"parameters.pickle\")\n\nX, y = finalise_fn(key, q, xs, sys)\n\nX = tree_utils.add_batch_dim(X)\n\n_, state = rnno.init(key, X)\n\nstate = tree_utils.add_batch_dim(state)\n\ny_hat, _ = rnno.apply(params, state, X)\ny_hat = tree_utils.to_2d_if_3d(y_hat, strict=True)\n
First we want to plot the angle error for both segment 2 and segment 3 over time.
y[\"seg2\"][:10]\n
y_hat[\"seg2\"]\n
fig, ax = plt.subplots()\n\nangle_error2 = jnp.rad2deg(x_xy.maths.angle_error(y[\"seg2\"], y_hat[\"seg2\"]))\nangle_error3 = jnp.rad2deg(x_xy.maths.angle_error(y[\"seg3\"], y_hat[\"seg3\"]))\n\nT = jnp.arange(angle_error2.size) * sys_inference.dt\n\nax.plot(T, angle_error2, label=\"seg2\")\nax.plot(T, angle_error3, label=\"seg3\")\n\nax.set_xlabel(\"time [s]\")\nax.set_ylabel(\"abs. angle error [deg]\")\n\nax.legend()\n\nplt.show()\n
Next we have to create an xs_hat
of the estimated orientations, so that we can render them.
# Extract translations from data-generating system...\ntranslations, rotations = sim2real.unzip_xs(\n sys_inference, sim2real.match_xs(sys_inference, xs, sys)\n)\n\ny_hat_inv = jax.tree_map(lambda quat: x_xy.maths.quat_inv(quat), y_hat) \n\n# ... swap rotations with predicted ones...\nrotations_hat = [] \nfor i, name in enumerate(sys_inference.link_names):\n if name in y_hat_inv:\n rotations_name = x_xy.Transform.create(rot=y_hat_inv[name])\n else:\n rotations_name = rotations.take(i, axis=1)\n rotations_hat.append(rotations_name)\n\n# ... and plug the positions and rotations back together.\nrotations_hat = rotations_hat[0].batch(*rotations_hat[1:]).transpose((1, 0, 2))\nxs_hat = sim2real.zip_xs(sys_inference, translations, rotations_hat)\n\n# Create combined system that shall be rendered and its transforms\nsys_render = sys_composer.inject_system(sys, sys_inference.add_prefix_suffix(suffix=\"_hat\"))\nxs_render = x_xy.Transform.concatenate(xs, xs_hat, axis=1)\n
Now we can render both the predicted system (in white) as well as the real system (in orange).
xs_list = [xs_render[i] for i in range(xs_render.shape())]\n\nframes = x_xy.render(sys_render, xs_list, camera=\"targetfar\")\nmediapy.show_video([frame[..., :3] for frame in frames], fps=int(1 / sys.dt))\n
\n
"},{"location":"prism/ss_23_moritz/notebook/#training-the-rnno-with-a-custom-loss-function","title":"Training the RNNo with a custom loss function","text":"This notebook showcases how train an RNNo network with a custom loss function rather than the default mean-reduces angle error. This is showcased by scaling the error by a softmax over the time axis, which puts more weight on the time intervals with a higher deviation compared to ones with lower deviation.
"},{"location":"prism/ss_23_moritz/notebook/#defining-the-systems","title":"Defining the systems","text":"We use two separate systems, both parsed from XML strings: one for training (sys
) and one for inference (dustin_sys
).
Our motion data will be automatically generated using a Generator
, which can be customised using an MotionConfig
. The Generator
will generate data for both q
, that is the state of all the joint angles in the system, as well as xs
, which describes the orientations of all the links in the system. To use this data for training our RNNo, we first have to bring it into the correct form using a finalise_fn
.
To customise the loss function of the RNNo, we transform the error values before they are averaged. The input to our loss function will be both \\(q\\), the real joint state, as well as \\(\\hat{q}\\), the joint space estimated by our RNNo. q
and q_hat
will both be jax.Array
s of shape (T_tbp, 4)
, where the first axis is slice over time (of our TBPTT length) and the second axis are the 4 components of a quaternion.
In this notebook we want to change the relative weightings of the errors at different times using a softmax function in order to put more weight on larger errors. First we convert the errors from quaterions to angles. Then we scale each error angle by a factor, calculated from a softmax over the angles. The calculation of the factors includes a call to jax.lax.stop_gradient
to make it so our gradients are only from the errors themselves, not the factors as well.
Supports Python=3.10/3.11/3.12
(tested).
Install with pip
using
pip install 'ring @ git+https://github.com/SimiPixel/ring'
Typically, this will install jax
as cpu-only version. Afterwards, gpu-enabled version can be installed with
pip install --upgrade \"jax[cuda12_pip]\" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html\n
"},{"location":"#documentation","title":"Documentation","text":"Available here.
"},{"location":"#known-fixes","title":"Known fixes","text":""},{"location":"#offscreen-rendering-with-mujoco","title":"Offscreen rendering with Mujoco","text":"mujoco.FatalError: an OpenGL platform library has not been loaded into this process, this most likely means that a valid OpenGL context has not been created before mjr_makeContext was called
Solution:
import os\nos.environ[\"MUJOCO_GL\"] = \"egl\"\n
"},{"location":"#publications","title":"Publications","text":"The following publications utilize this software library, and refer to it as the Random Chain Motion Generator (RCMG) (more specifically the function ring.RCMG
):
Particularly useful is the following publication from Roy Featherstone - A Beginner\u2019s Guide to 6-D Vectors (Part 2)
"},{"location":"#contact","title":"Contact","text":"Simon Bachhuber (simon.bachhuber@fau.de)
"},{"location":"api/","title":"Api","text":""},{"location":"api/#ring.ml.ringnet.RING","title":"RING
","text":"Source code in src/ring/ml/ringnet.py
class RING(ml_base.AbstractFilter):\n def __init__(self, params=None, lam=None, jit: bool = True, name=None, **kwargs):\n self.forward_lam_factory = partial(make_ring, **kwargs)\n self.params = self._load_params(params)\n self.lam = lam\n self._name = name\n\n if jit:\n self.apply = jax.jit(self.apply, static_argnames=\"lam\")\n\n def apply(self, X, params=None, state=None, y=None, lam=None):\n if lam is None:\n assert self.lam is not None\n lam = self.lam\n\n return super().apply(X, params, state, y, tuple(lam))\n\n def init(self, bs: Optional[int] = None, X=None, lam=None, seed: int = 1):\n assert X is not None, \"Providing `X` via in `ringnet.init(X=X)` is required\"\n if bs is not None:\n assert X.ndim == 4\n\n if X.ndim == 4:\n if bs is not None:\n assert bs == X.shape[0]\n else:\n bs = X.shape[0]\n X = X[0]\n\n # (T, N, F) -> (1, N, F) for faster .init call\n X = X[0:1]\n\n if lam is None:\n assert self.lam is not None\n lam = self.lam\n\n key = jax.random.PRNGKey(seed)\n params, state = self.forward_lam_factory(lam=lam).init(key, X)\n\n if bs is not None:\n state = jax.tree_map(lambda arr: jnp.repeat(arr[None], bs, axis=0), state)\n\n return params, state\n\n def _apply_batched(self, X, params, state, y, lam):\n if (params is None and self.params is None) or state is None:\n _params, _state = self.init(bs=X.shape[0], X=X, lam=lam)\n\n if params is None and self.params is None:\n params = _params\n elif params is None:\n params = self.params\n else:\n pass\n\n if state is None:\n state = _state\n\n yhat, next_state = jax.vmap(\n self.forward_lam_factory(lam=lam).apply, in_axes=(None, 0, 0)\n )(params, state, X)\n\n return yhat, next_state\n\n @staticmethod\n def _load_params(params: str | dict | None | Path):\n assert isinstance(params, (str, dict, type(None), Path))\n if isinstance(params, (Path, str)):\n return pickle_load(params)\n return params\n\n def nojit(self) -> \"RING\":\n ringnet = RING(params=self.params, lam=self.lam, jit=False)\n ringnet.forward_lam_factory = self.forward_lam_factory\n return ringnet\n\n def _pre_save(self, params=None, lam=None) -> None:\n if params is not None:\n self.params = params\n if lam is not None:\n self.lam = lam\n\n @staticmethod\n def _post_load(ringnet: \"RING\", jit: bool = True) -> \"RING\":\n if jit:\n ringnet.apply = jax.jit(ringnet.apply, static_argnames=\"lam\")\n return ringnet\n
"},{"location":"api/#ring.ml.RING_ICML24","title":"RING_ICML24(**kwargs)
","text":"Source code in src/ring/ml/__init__.py
def RING_ICML24(**kwargs):\n from pathlib import Path\n\n params = Path(__file__).parent.joinpath(\"params/0x13e3518065c21cd8.pickle\")\n ringnet = RING(params=params, **kwargs) # noqa: F811\n ringnet = base.ScaleX_FilterWrapper(ringnet)\n ringnet = base.LPF_FilterWrapper(ringnet, 10.0, samp_freq=None)\n ringnet = base.GroundTruthHeading_FilterWrapper(ringnet)\n return ringnet\n
"},{"location":"api/#ring.base.System","title":"System
","text":"Source code in src/ring/base.py
@struct.dataclass\nclass System(_Base):\n link_parents: list[int] = struct.field(False)\n links: Link\n link_types: list[str] = struct.field(False)\n link_damping: jax.Array\n link_armature: jax.Array\n link_spring_stiffness: jax.Array\n link_spring_zeropoint: jax.Array\n # simulation timestep size\n dt: float = struct.field(False)\n # geometries in the system\n geoms: list[Geometry]\n # root / base acceleration offset\n gravity: jax.Array = struct.field(default_factory=lambda: jnp.array([0, 0, -9.81]))\n\n integration_method: str = struct.field(\n False, default_factory=lambda: \"semi_implicit_euler\"\n )\n mass_mat_iters: int = struct.field(False, default_factory=lambda: 0)\n\n link_names: list[str] = struct.field(False, default_factory=lambda: [])\n\n model_name: Optional[str] = struct.field(False, default_factory=lambda: None)\n\n omc: list[MaxCoordOMC | None] = struct.field(True, default_factory=lambda: [])\n\n def num_links(self) -> int:\n return len(self.link_parents)\n\n def q_size(self) -> int:\n return sum([Q_WIDTHS[typ] for typ in self.link_types])\n\n def qd_size(self) -> int:\n return sum([QD_WIDTHS[typ] for typ in self.link_types])\n\n def name_to_idx(self, name: str) -> int:\n return self.link_names.index(name)\n\n def idx_to_name(self, idx: int, allow_world: bool = False) -> str:\n if allow_world and idx == -1:\n return \"world\"\n assert idx >= 0, \"Worldbody index has no name.\"\n return self.link_names[idx]\n\n def idx_map(self, type: str) -> dict:\n \"type: is either `l` or `q` or `d`\"\n dict_int_slices = {}\n\n def f(_, idx_map, name: str, link_idx: int):\n dict_int_slices[name] = idx_map[type](link_idx)\n\n self.scan(f, \"ll\", self.link_names, list(range(self.num_links())))\n\n return dict_int_slices\n\n def parent_name(self, name: str) -> str:\n return self.idx_to_name(self.link_parents[self.name_to_idx(name)])\n\n def add_prefix(self, prefix: str = \"\") -> \"System\":\n return self.replace(link_names=[prefix + name for name in self.link_names])\n\n def change_model_name(\n self,\n new_name: Optional[str] = None,\n prefix: Optional[str] = None,\n suffix: Optional[str] = None,\n ) -> \"System\":\n if prefix is None:\n prefix = \"\"\n if suffix is None:\n suffix = \"\"\n if new_name is None:\n new_name = self.model_name\n name = prefix + new_name + suffix\n return self.replace(model_name=name)\n\n def change_link_name(self, old_name: str, new_name: str) -> \"System\":\n old_idx = self.name_to_idx(old_name)\n new_link_names = self.link_names.copy()\n new_link_names[old_idx] = new_name\n return self.replace(link_names=new_link_names)\n\n def add_prefix_suffix(\n self, prefix: Optional[str] = None, suffix: Optional[str] = None\n ) -> \"System\":\n if prefix is None:\n prefix = \"\"\n if suffix is None:\n suffix = \"\"\n new_link_names = [prefix + name + suffix for name in self.link_names]\n return self.replace(link_names=new_link_names)\n\n @staticmethod\n def deep_equal(a, b):\n if type(a) is not type(b):\n return False\n if isinstance(a, _Base):\n return System.deep_equal(a.__dict__, b.__dict__)\n if isinstance(a, dict):\n if a.keys() != b.keys():\n return False\n return all(System.deep_equal(a[k], b[k]) for k in a.keys())\n if isinstance(a, (list, tuple)):\n if len(a) != len(b):\n return False\n return all(System.deep_equal(a[i], b[i]) for i in range(len(a)))\n if isinstance(a, (np.ndarray, jnp.ndarray, jax.Array)):\n return jnp.array_equal(a, b)\n return a == b\n\n def _replace_free_with_cor(self) -> \"System\":\n # check that\n # - all free joints connect to -1\n # - all joints connecting to -1 are free joints\n for i, p in enumerate(self.link_parents):\n link_type = self.link_types[i]\n if (p == -1 and link_type != \"free\") or (link_type == \"free\" and p != -1):\n raise InvalidSystemError(\n f\"link={self.idx_to_name(i)}, parent=\"\n f\"{self.idx_to_name(p, allow_world=True)},\"\n f\" joint={link_type}. Hint: Try setting `config.cor` to false.\"\n )\n\n def logic_replace_free_with_cor(name, olt, ola, old, ols, olz):\n # by default new is equal to old\n nlt, nla, nld, nls, nlz = olt, ola, old, ols, olz\n\n # old link type == free\n if olt == \"free\":\n # cor joint is (free, p3d) stacked\n nlt = \"cor\"\n # entries of old armature are 3*ang (spherical), 3*pos (p3d)\n nla = jnp.concatenate((ola, ola[3:]))\n nld = jnp.concatenate((old, old[3:]))\n nls = jnp.concatenate((ols, ols[3:]))\n nlz = jnp.concatenate((olz, olz[4:]))\n\n return nlt, nla, nld, nls, nlz\n\n return _update_sys_if_replace_joint_type(self, logic_replace_free_with_cor)\n\n def freeze(self, name: str | list[str]):\n if isinstance(name, list):\n sys = self\n for n in name:\n sys = sys.freeze(n)\n return sys\n\n def logic_freeze(link_name, olt, ola, old, ols, olz):\n nlt, nla, nld, nls, nlz = olt, ola, old, ols, olz\n\n if link_name == name:\n nlt = \"frozen\"\n nla = nld = nls = nlz = jnp.array([])\n\n return nlt, nla, nld, nls, nlz\n\n return _update_sys_if_replace_joint_type(self, logic_freeze)\n\n def unfreeze(self, name: str, new_joint_type: str):\n assert self.link_types[self.name_to_idx(name)] == \"frozen\"\n assert new_joint_type != \"frozen\"\n\n return self.change_joint_type(name, new_joint_type)\n\n def change_joint_type(\n self,\n name: str,\n new_joint_type: str,\n new_arma: Optional[jax.Array] = None,\n new_damp: Optional[jax.Array] = None,\n new_stif: Optional[jax.Array] = None,\n new_zero: Optional[jax.Array] = None,\n ):\n \"By default damping, stiffness are set to zero.\"\n q_size, qd_size = Q_WIDTHS[new_joint_type], QD_WIDTHS[new_joint_type]\n\n def logic_unfreeze_to_spherical(link_name, olt, ola, old, ols, olz):\n nlt, nla, nld, nls, nlz = olt, ola, old, ols, olz\n\n if link_name == name:\n nlt = new_joint_type\n q_zeros = jnp.zeros((q_size))\n qd_zeros = jnp.zeros((qd_size,))\n\n nla = qd_zeros if new_arma is None else new_arma\n nld = qd_zeros if new_damp is None else new_damp\n nls = qd_zeros if new_stif is None else new_stif\n nlz = q_zeros if new_zero is None else new_zero\n\n # unit quaternion\n if new_joint_type in [\"spherical\", \"free\", \"cor\"] and new_zero is None:\n nlz = nlz.at[0].set(1.0)\n\n return nlt, nla, nld, nls, nlz\n\n return _update_sys_if_replace_joint_type(self, logic_unfreeze_to_spherical)\n\n def findall_imus(self) -> list[str]:\n return [name for name in self.link_names if name[:3] == \"imu\"]\n\n def findall_segments(self) -> list[str]:\n imus = self.findall_imus()\n return [name for name in self.link_names if name not in imus]\n\n def _bodies_indices_to_bodies_name(self, bodies: list[int]) -> list[str]:\n return [self.idx_to_name(i) for i in bodies]\n\n def findall_bodies_to_world(self, names: bool = False) -> list[int] | list[str]:\n bodies = [i for i, p in enumerate(self.link_parents) if p == -1]\n return self._bodies_indices_to_bodies_name(bodies) if names else bodies\n\n def find_body_to_world(self, name: bool = False) -> int | str:\n bodies = self.findall_bodies_to_world(names=name)\n assert len(bodies) == 1\n return bodies[0]\n\n def findall_bodies_with_jointtype(\n self, typ: str, names: bool = False\n ) -> list[int] | list[str]:\n bodies = [i for i, _typ in enumerate(self.link_types) if _typ == typ]\n return self._bodies_indices_to_bodies_name(bodies) if names else bodies\n\n def scan(self, f: Callable, in_types: str, *args, reverse: bool = False):\n \"\"\"Scan `f` along each link in system whilst carrying along state.\n\n Args:\n f (Callable[..., Y]): f(y: Y, *args) -> y\n in_types: string specifying the type of each input arg:\n 'l' is an input to be split according to link ranges\n 'q' is an input to be split according to q ranges\n 'd' is an input to be split according to qd ranges\n args: Arguments passed to `f`, and split to match the link.\n reverse (bool, optional): If `true` from leaves to root. Defaults to False.\n\n Returns:\n ys: Stacked output y of f.\n \"\"\"\n return _scan_sys(self, f, in_types, *args, reverse=reverse)\n\n def parse(self) -> \"System\":\n \"\"\"Initial setup of system. System object does not work unless it is parsed.\n Currently it does:\n - some consistency checks\n - populate the spatial inertia tensors\n - check that all names are unique\n - check that names are strings\n - check that all pos_min <= pos_max (unless traced)\n - order geoms in ascending order based on their parent link idx\n - check that all links have the correct size of\n - damping\n - armature\n - stiffness\n - zeropoint\n - check that n_links == len(sys.omc)\n \"\"\"\n return _parse_system(self)\n\n def render(\n self,\n xs: Optional[Transform | list[Transform]] = None,\n camera: Optional[str] = None,\n show_pbar: bool = True,\n backend: str = \"mujoco\",\n render_every_nth: int = 1,\n **scene_kwargs,\n ) -> list[np.ndarray]:\n \"\"\"Render frames from system and trajectory of maximal coordinates `xs`.\n\n Args:\n sys (base.System): System to render.\n xs (base.Transform | list[base.Transform]): Single or time-series\n of maximal coordinates `xs`.\n show_pbar (bool, optional): Whether or not to show a progress bar.\n Defaults to True.\n\n Returns:\n list[np.ndarray]: Stacked rendered frames. Length == len(xs).\n \"\"\"\n return ring.rendering.render(\n self, xs, camera, show_pbar, backend, render_every_nth, **scene_kwargs\n )\n\n def render_prediction(\n self,\n xs: Transform | list[Transform],\n yhat: dict,\n stepframe: int = 1,\n # by default we don't predict the global rotation\n transparent_segment_to_root: bool = True,\n **kwargs,\n ):\n \"`xs` matches `sys`. `yhat` matches `sys_noimu`. `yhat` are child-to-parent.\"\n return ring.rendering.render_prediction(\n self, xs, yhat, stepframe, transparent_segment_to_root, **kwargs\n )\n\n def delete_system(self, link_name: str | list[str], strict: bool = True):\n \"Cut subsystem starting at `link_name` (inclusive) from tree.\"\n return ring.sys_composer.delete_subsystem(self, link_name, strict)\n\n def make_sys_noimu(self, imu_link_names: Optional[list[str]] = None):\n \"Returns, e.g., imu_attachment = {'imu1': 'seg1', 'imu2': 'seg3'}\"\n return ring.sys_composer.make_sys_noimu(self, imu_link_names)\n\n def inject_system(self, other_system: \"System\", at_body: Optional[str] = None):\n \"\"\"Combine two systems into one.\n\n Args:\n sys (base.System): Large system.\n sub_sys (base.System): Small system that will be included into the\n large system `sys`.\n at_body (Optional[str], optional): Into which body of the large system\n small system will be included. Defaults to `worldbody`.\n\n Returns:\n base.System: _description_\n \"\"\"\n return ring.sys_composer.inject_system(self, other_system, at_body)\n\n def morph_system(\n self,\n new_parents: Optional[list[int | str]] = None,\n new_anchor: Optional[int | str] = None,\n ):\n \"\"\"Re-orders the graph underlying the system. Returns a new system.\n\n Args:\n sys (base.System): System to be modified.\n new_parents (list[int]): Let the i-th entry have value j. Then, after\n morphing the system the system will be such that the link corresponding\n to the i-th link in the old system will have as parent the link\n corresponding to the j-th link in the old system.\n\n Returns:\n base.System: Modified system.\n \"\"\"\n return ring.sys_composer.morph_system(self, new_parents, new_anchor)\n\n @staticmethod\n def from_xml(path: str, seed: int = 1):\n return ring.io.load_sys_from_xml(path, seed)\n\n @staticmethod\n def from_str(xml: str, seed: int = 1):\n return ring.io.load_sys_from_str(xml, seed)\n\n def to_str(self) -> str:\n return ring.io.save_sys_to_str(self)\n\n def to_xml(self, path: str) -> None:\n ring.io.save_sys_to_xml(self, path)\n\n @classmethod\n def create(cls, path_or_str: str, seed: int = 1) -> \"System\":\n path = Path(path_or_str).with_suffix(\".xml\")\n if path.exists():\n return cls.from_xml(path, seed=seed)\n else:\n return cls.from_str(path_or_str)\n\n def coordinate_vector_to_q(\n self,\n q: jax.Array,\n custom_joints: dict[str, Callable] = {},\n ) -> jax.Array:\n \"\"\"Map a coordinate vector `q` to the minimal coordinates vector of the sys\"\"\"\n # Does, e.g.\n # - normalize quaternions\n # - hinge joints in [-pi, pi]\n q_preproc = []\n\n def preprocess(_, __, link_type, q):\n to_q = ring.algorithms.jcalc.get_joint_model(\n link_type\n ).coordinate_vector_to_q\n # function in custom_joints has priority over JointModel\n if link_type in custom_joints:\n to_q = custom_joints[link_type]\n if to_q is None:\n raise NotImplementedError(\n f\"Please specify the custom joint `{link_type}`\"\n \" either using the `custom_joints` arguments or using the\"\n \" JointModel.coordinate_vector_to_q field.\"\n )\n new_q = to_q(q)\n q_preproc.append(new_q)\n\n self.scan(preprocess, \"lq\", self.link_types, q)\n return jnp.concatenate(q_preproc)\n
"},{"location":"api/#ring.base.System.idx_map","title":"idx_map(type)
","text":"type: is either l
or q
or d
src/ring/base.py
def idx_map(self, type: str) -> dict:\n \"type: is either `l` or `q` or `d`\"\n dict_int_slices = {}\n\n def f(_, idx_map, name: str, link_idx: int):\n dict_int_slices[name] = idx_map[type](link_idx)\n\n self.scan(f, \"ll\", self.link_names, list(range(self.num_links())))\n\n return dict_int_slices\n
"},{"location":"api/#ring.base.System.change_joint_type","title":"change_joint_type(name, new_joint_type, new_arma=None, new_damp=None, new_stif=None, new_zero=None)
","text":"By default damping, stiffness are set to zero.
Source code insrc/ring/base.py
def change_joint_type(\n self,\n name: str,\n new_joint_type: str,\n new_arma: Optional[jax.Array] = None,\n new_damp: Optional[jax.Array] = None,\n new_stif: Optional[jax.Array] = None,\n new_zero: Optional[jax.Array] = None,\n):\n \"By default damping, stiffness are set to zero.\"\n q_size, qd_size = Q_WIDTHS[new_joint_type], QD_WIDTHS[new_joint_type]\n\n def logic_unfreeze_to_spherical(link_name, olt, ola, old, ols, olz):\n nlt, nla, nld, nls, nlz = olt, ola, old, ols, olz\n\n if link_name == name:\n nlt = new_joint_type\n q_zeros = jnp.zeros((q_size))\n qd_zeros = jnp.zeros((qd_size,))\n\n nla = qd_zeros if new_arma is None else new_arma\n nld = qd_zeros if new_damp is None else new_damp\n nls = qd_zeros if new_stif is None else new_stif\n nlz = q_zeros if new_zero is None else new_zero\n\n # unit quaternion\n if new_joint_type in [\"spherical\", \"free\", \"cor\"] and new_zero is None:\n nlz = nlz.at[0].set(1.0)\n\n return nlt, nla, nld, nls, nlz\n\n return _update_sys_if_replace_joint_type(self, logic_unfreeze_to_spherical)\n
"},{"location":"api/#ring.base.System.scan","title":"scan(f, in_types, *args, reverse=False)
","text":"Scan f
along each link in system whilst carrying along state.
Parameters:
Name Type Description Defaultf
Callable[..., Y]
f(y: Y, *args) -> y
requiredin_types
str
string specifying the type of each input arg: 'l' is an input to be split according to link ranges 'q' is an input to be split according to q ranges 'd' is an input to be split according to qd ranges
requiredargs
Arguments passed to f
, and split to match the link.
()
reverse
bool
If true
from leaves to root. Defaults to False.
False
Returns:
Name Type Descriptionys
Stacked output y of f.
Source code insrc/ring/base.py
def scan(self, f: Callable, in_types: str, *args, reverse: bool = False):\n \"\"\"Scan `f` along each link in system whilst carrying along state.\n\n Args:\n f (Callable[..., Y]): f(y: Y, *args) -> y\n in_types: string specifying the type of each input arg:\n 'l' is an input to be split according to link ranges\n 'q' is an input to be split according to q ranges\n 'd' is an input to be split according to qd ranges\n args: Arguments passed to `f`, and split to match the link.\n reverse (bool, optional): If `true` from leaves to root. Defaults to False.\n\n Returns:\n ys: Stacked output y of f.\n \"\"\"\n return _scan_sys(self, f, in_types, *args, reverse=reverse)\n
"},{"location":"api/#ring.base.System.parse","title":"parse()
","text":"Initial setup of system. System object does not work unless it is parsed. Currently it does: - some consistency checks - populate the spatial inertia tensors - check that all names are unique - check that names are strings - check that all pos_min <= pos_max (unless traced) - order geoms in ascending order based on their parent link idx - check that all links have the correct size of - damping - armature - stiffness - zeropoint - check that n_links == len(sys.omc)
Source code insrc/ring/base.py
def parse(self) -> \"System\":\n \"\"\"Initial setup of system. System object does not work unless it is parsed.\n Currently it does:\n - some consistency checks\n - populate the spatial inertia tensors\n - check that all names are unique\n - check that names are strings\n - check that all pos_min <= pos_max (unless traced)\n - order geoms in ascending order based on their parent link idx\n - check that all links have the correct size of\n - damping\n - armature\n - stiffness\n - zeropoint\n - check that n_links == len(sys.omc)\n \"\"\"\n return _parse_system(self)\n
"},{"location":"api/#ring.base.System.render","title":"render(xs=None, camera=None, show_pbar=True, backend='mujoco', render_every_nth=1, **scene_kwargs)
","text":"Render frames from system and trajectory of maximal coordinates xs
.
Parameters:
Name Type Description Defaultsys
System
System to render.
requiredxs
Transform | list[Transform]
Single or time-series
None
show_pbar
bool
Whether or not to show a progress bar.
True
Returns:
Type Descriptionlist[ndarray]
list[np.ndarray]: Stacked rendered frames. Length == len(xs).
Source code insrc/ring/base.py
def render(\n self,\n xs: Optional[Transform | list[Transform]] = None,\n camera: Optional[str] = None,\n show_pbar: bool = True,\n backend: str = \"mujoco\",\n render_every_nth: int = 1,\n **scene_kwargs,\n) -> list[np.ndarray]:\n \"\"\"Render frames from system and trajectory of maximal coordinates `xs`.\n\n Args:\n sys (base.System): System to render.\n xs (base.Transform | list[base.Transform]): Single or time-series\n of maximal coordinates `xs`.\n show_pbar (bool, optional): Whether or not to show a progress bar.\n Defaults to True.\n\n Returns:\n list[np.ndarray]: Stacked rendered frames. Length == len(xs).\n \"\"\"\n return ring.rendering.render(\n self, xs, camera, show_pbar, backend, render_every_nth, **scene_kwargs\n )\n
"},{"location":"api/#ring.base.System.render_prediction","title":"render_prediction(xs, yhat, stepframe=1, transparent_segment_to_root=True, **kwargs)
","text":"xs
matches sys
. yhat
matches sys_noimu
. yhat
are child-to-parent.
src/ring/base.py
def render_prediction(\n self,\n xs: Transform | list[Transform],\n yhat: dict,\n stepframe: int = 1,\n # by default we don't predict the global rotation\n transparent_segment_to_root: bool = True,\n **kwargs,\n):\n \"`xs` matches `sys`. `yhat` matches `sys_noimu`. `yhat` are child-to-parent.\"\n return ring.rendering.render_prediction(\n self, xs, yhat, stepframe, transparent_segment_to_root, **kwargs\n )\n
"},{"location":"api/#ring.base.System.delete_system","title":"delete_system(link_name, strict=True)
","text":"Cut subsystem starting at link_name
(inclusive) from tree.
src/ring/base.py
def delete_system(self, link_name: str | list[str], strict: bool = True):\n \"Cut subsystem starting at `link_name` (inclusive) from tree.\"\n return ring.sys_composer.delete_subsystem(self, link_name, strict)\n
"},{"location":"api/#ring.base.System.make_sys_noimu","title":"make_sys_noimu(imu_link_names=None)
","text":"Returns, e.g., imu_attachment = {'imu1': 'seg1', 'imu2': 'seg3'}
Source code insrc/ring/base.py
def make_sys_noimu(self, imu_link_names: Optional[list[str]] = None):\n \"Returns, e.g., imu_attachment = {'imu1': 'seg1', 'imu2': 'seg3'}\"\n return ring.sys_composer.make_sys_noimu(self, imu_link_names)\n
"},{"location":"api/#ring.base.System.inject_system","title":"inject_system(other_system, at_body=None)
","text":"Combine two systems into one.
Parameters:
Name Type Description Defaultsys
System
Large system.
requiredsub_sys
System
Small system that will be included into the large system sys
.
at_body
Optional[str]
Into which body of the large system small system will be included. Defaults to worldbody
.
None
Returns:
Type Descriptionbase.System: description
Source code insrc/ring/base.py
def inject_system(self, other_system: \"System\", at_body: Optional[str] = None):\n \"\"\"Combine two systems into one.\n\n Args:\n sys (base.System): Large system.\n sub_sys (base.System): Small system that will be included into the\n large system `sys`.\n at_body (Optional[str], optional): Into which body of the large system\n small system will be included. Defaults to `worldbody`.\n\n Returns:\n base.System: _description_\n \"\"\"\n return ring.sys_composer.inject_system(self, other_system, at_body)\n
"},{"location":"api/#ring.base.System.morph_system","title":"morph_system(new_parents=None, new_anchor=None)
","text":"Re-orders the graph underlying the system. Returns a new system.
Parameters:
Name Type Description Defaultsys
System
System to be modified.
requirednew_parents
list[int]
Let the i-th entry have value j. Then, after morphing the system the system will be such that the link corresponding to the i-th link in the old system will have as parent the link corresponding to the j-th link in the old system.
None
Returns:
Type Descriptionbase.System: Modified system.
Source code insrc/ring/base.py
def morph_system(\n self,\n new_parents: Optional[list[int | str]] = None,\n new_anchor: Optional[int | str] = None,\n):\n \"\"\"Re-orders the graph underlying the system. Returns a new system.\n\n Args:\n sys (base.System): System to be modified.\n new_parents (list[int]): Let the i-th entry have value j. Then, after\n morphing the system the system will be such that the link corresponding\n to the i-th link in the old system will have as parent the link\n corresponding to the j-th link in the old system.\n\n Returns:\n base.System: Modified system.\n \"\"\"\n return ring.sys_composer.morph_system(self, new_parents, new_anchor)\n
"},{"location":"api/#ring.base.System.coordinate_vector_to_q","title":"coordinate_vector_to_q(q, custom_joints={})
","text":"Map a coordinate vector q
to the minimal coordinates vector of the sys
src/ring/base.py
def coordinate_vector_to_q(\n self,\n q: jax.Array,\n custom_joints: dict[str, Callable] = {},\n) -> jax.Array:\n \"\"\"Map a coordinate vector `q` to the minimal coordinates vector of the sys\"\"\"\n # Does, e.g.\n # - normalize quaternions\n # - hinge joints in [-pi, pi]\n q_preproc = []\n\n def preprocess(_, __, link_type, q):\n to_q = ring.algorithms.jcalc.get_joint_model(\n link_type\n ).coordinate_vector_to_q\n # function in custom_joints has priority over JointModel\n if link_type in custom_joints:\n to_q = custom_joints[link_type]\n if to_q is None:\n raise NotImplementedError(\n f\"Please specify the custom joint `{link_type}`\"\n \" either using the `custom_joints` arguments or using the\"\n \" JointModel.coordinate_vector_to_q field.\"\n )\n new_q = to_q(q)\n q_preproc.append(new_q)\n\n self.scan(preprocess, \"lq\", self.link_types, q)\n return jnp.concatenate(q_preproc)\n
"},{"location":"api/#ring.base.State","title":"State
","text":"The static and dynamic state of a system in minimal and maximal coordinates. Use .create()
to create this object.
Parameters:
Name Type Description Defaultq
Array
System state in minimal coordinates (equals sys.q_size()
)
qd
Array
System velocity in minimal coordinates (equals sys.qd_size()
)
x
(Transform): Maximal coordinates of all links. From epsilon-to-link.
requiredmass_mat_inv
Array
Inverse of the mass matrix. Internal usage.
required Source code insrc/ring/base.py
@struct.dataclass\nclass State(_Base):\n \"\"\"The static and dynamic state of a system in minimal and maximal coordinates.\n Use `.create()` to create this object.\n\n Args:\n q (jax.Array): System state in minimal coordinates (equals `sys.q_size()`)\n qd (jax.Array): System velocity in minimal coordinates (equals `sys.qd_size()`)\n x: (Transform): Maximal coordinates of all links. From epsilon-to-link.\n mass_mat_inv (jax.Array): Inverse of the mass matrix. Internal usage.\n \"\"\"\n\n q: jax.Array\n qd: jax.Array\n x: Transform\n mass_mat_inv: jax.Array\n\n @classmethod\n def create(\n cls,\n sys: System,\n q: Optional[jax.Array] = None,\n qd: Optional[jax.Array] = None,\n x: Optional[Transform] = None,\n key: Optional[jax.Array] = None,\n custom_joints: dict[str, Callable] = {},\n ):\n \"\"\"Create state of system.\n\n Args:\n sys (System): The system for which to create a state.\n q (jax.Array, optional): The joint values of the system. Defaults to None.\n Which then defaults to zeros.\n qd (jax.Array, optional): The joint velocities of the system.\n Defaults to None. Which then defaults to zeros.\n\n Returns:\n (State): Create State object.\n \"\"\"\n if key is not None:\n assert q is None\n q = jax.random.normal(key, shape=(sys.q_size(),))\n q = sys.coordinate_vector_to_q(q, custom_joints)\n elif q is None:\n q = jnp.zeros((sys.q_size(),))\n\n # free, cor, spherical joints are not zeros but have unit quaternions\n def replace_by_unit_quat(_, idx_map, link_typ, link_idx):\n nonlocal q\n\n if link_typ in [\"free\", \"cor\", \"spherical\"]:\n q_idxs_link = idx_map[\"q\"](link_idx)\n q = q.at[q_idxs_link.start].set(1.0)\n\n sys.scan(\n replace_by_unit_quat,\n \"ll\",\n sys.link_types,\n list(range(sys.num_links())),\n )\n else:\n pass\n\n if qd is None:\n qd = jnp.zeros((sys.qd_size(),))\n\n if x is None:\n x = Transform.zero((sys.num_links(),))\n\n return cls(q, qd, x, jnp.diag(jnp.ones((sys.qd_size(),))))\n
"},{"location":"api/#ring.base.State.create","title":"create(sys, q=None, qd=None, x=None, key=None, custom_joints={})
classmethod
","text":"Create state of system.
Parameters:
Name Type Description Defaultsys
System
The system for which to create a state.
requiredq
Array
The joint values of the system. Defaults to None.
None
qd
Array
The joint velocities of the system.
None
Returns:
Type DescriptionState
Create State object.
Source code insrc/ring/base.py
@classmethod\ndef create(\n cls,\n sys: System,\n q: Optional[jax.Array] = None,\n qd: Optional[jax.Array] = None,\n x: Optional[Transform] = None,\n key: Optional[jax.Array] = None,\n custom_joints: dict[str, Callable] = {},\n):\n \"\"\"Create state of system.\n\n Args:\n sys (System): The system for which to create a state.\n q (jax.Array, optional): The joint values of the system. Defaults to None.\n Which then defaults to zeros.\n qd (jax.Array, optional): The joint velocities of the system.\n Defaults to None. Which then defaults to zeros.\n\n Returns:\n (State): Create State object.\n \"\"\"\n if key is not None:\n assert q is None\n q = jax.random.normal(key, shape=(sys.q_size(),))\n q = sys.coordinate_vector_to_q(q, custom_joints)\n elif q is None:\n q = jnp.zeros((sys.q_size(),))\n\n # free, cor, spherical joints are not zeros but have unit quaternions\n def replace_by_unit_quat(_, idx_map, link_typ, link_idx):\n nonlocal q\n\n if link_typ in [\"free\", \"cor\", \"spherical\"]:\n q_idxs_link = idx_map[\"q\"](link_idx)\n q = q.at[q_idxs_link.start].set(1.0)\n\n sys.scan(\n replace_by_unit_quat,\n \"ll\",\n sys.link_types,\n list(range(sys.num_links())),\n )\n else:\n pass\n\n if qd is None:\n qd = jnp.zeros((sys.qd_size(),))\n\n if x is None:\n x = Transform.zero((sys.num_links(),))\n\n return cls(q, qd, x, jnp.diag(jnp.ones((sys.qd_size(),))))\n
"},{"location":"api/#ring.algorithms.dynamics.step","title":"step(sys, state, taus=None, n_substeps=1)
","text":"Source code in src/ring/algorithms/dynamics.py
def step(\n sys: base.System,\n state: base.State,\n taus: Optional[jax.Array] = None,\n n_substeps: int = 1,\n) -> base.State:\n assert sys.q_size() == state.q.size\n if taus is None:\n taus = jnp.zeros_like(state.qd)\n assert sys.qd_size() == state.qd.size == taus.size\n assert (\n sys.integration_method.lower() == \"semi_implicit_euler\"\n ), \"Currently, nothing else then `semi_implicit_euler` implemented.\"\n\n sys = sys.replace(dt=sys.dt / n_substeps)\n\n for _ in range(n_substeps):\n # update kinematics before stepping; this means that the `x` in `state`\n # will lag one step behind but otherwise we would have to return\n # the system object which would be awkward\n sys, state = kinematics.forward_kinematics(sys, state)\n state = _integration_methods[sys.integration_method.lower()](sys, state, taus)\n\n return state\n
"},{"location":"api/#ring.base.Transform","title":"Transform
","text":"Represents the Transformation from Pl\u00fccker A to Pl\u00fccker B, where B is located relative to A at pos
in frame A and rot
is the relative quaternion from A to B.
src/ring/base.py
@struct.dataclass\nclass Transform(_Base):\n \"\"\"Represents the Transformation from Pl\u00fccker A to Pl\u00fccker B,\n where B is located relative to A at `pos` in frame A and `rot` is the\n relative quaternion from A to B.\"\"\"\n\n pos: Vector\n rot: Quaternion\n\n @classmethod\n def create(cls, pos=None, rot=None):\n assert not (pos is None and rot is None), \"One must be given.\"\n shape_rot = rot.shape[:-1] if rot is not None else ()\n shape_pos = pos.shape[:-1] if pos is not None else ()\n\n if pos is None:\n pos = jnp.zeros(shape_rot + (3,))\n if rot is None:\n rot = jnp.array([1.0, 0, 0, 0])\n rot = jnp.tile(jnp.array([1.0, 0.0, 0.0, 0.0]), shape_pos + (1,))\n\n assert pos.shape[:-1] == rot.shape[:-1]\n\n return Transform(pos, rot)\n\n @classmethod\n def zero(cls, shape=()) -> \"Transform\":\n \"\"\"Returns a zero transform with a batch shape.\"\"\"\n pos = jnp.zeros(shape + (3,))\n rot = jnp.tile(jnp.array([1.0, 0.0, 0.0, 0.0]), shape + (1,))\n return Transform(pos, rot)\n\n def as_matrix(self) -> jax.Array:\n E = maths.quat_to_3x3(self.rot)\n return spatial.quadrants(aa=E, bb=E) @ spatial.xlt(self.pos)\n
"},{"location":"api/#ring.base.Transform.zero","title":"zero(shape=())
classmethod
","text":"Returns a zero transform with a batch shape.
Source code insrc/ring/base.py
@classmethod\ndef zero(cls, shape=()) -> \"Transform\":\n \"\"\"Returns a zero transform with a batch shape.\"\"\"\n pos = jnp.zeros(shape + (3,))\n rot = jnp.tile(jnp.array([1.0, 0.0, 0.0, 0.0]), shape + (1,))\n return Transform(pos, rot)\n
"},{"location":"api/#ring.algorithms.generator.base.RCMG","title":"RCMG
","text":"Source code in src/ring/algorithms/generator/base.py
class RCMG:\n def __init__(\n self,\n sys: base.System | list[base.System],\n config: jcalc.MotionConfig | list[jcalc.MotionConfig] = jcalc.MotionConfig(),\n setup_fn: Optional[types.SETUP_FN] = None,\n finalize_fn: Optional[types.FINALIZE_FN] = None,\n add_X_imus: bool = False,\n add_X_imus_kwargs: Optional[dict] = None,\n add_X_jointaxes: bool = False,\n add_X_jointaxes_kwargs: Optional[dict] = None,\n add_y_relpose: bool = False,\n add_y_rootincl: bool = False,\n sys_ml: Optional[base.System] = None,\n randomize_positions: bool = False,\n randomize_motion_artifacts: bool = False,\n randomize_joint_params: bool = False,\n randomize_anchors: bool = False,\n randomize_anchors_kwargs: Optional[dict] = None,\n randomize_hz: bool = False,\n randomize_hz_kwargs: Optional[dict] = None,\n imu_motion_artifacts: bool = False,\n imu_motion_artifacts_kwargs: Optional[dict] = None,\n dynamic_simulation: bool = False,\n dynamic_simulation_kwargs: Optional[dict] = None,\n output_transform: Optional[Callable] = None,\n keep_output_extras: bool = False,\n use_link_number_in_Xy: bool = False,\n ) -> None:\n\n randomize_anchors_kwargs = _copy_kwargs(randomize_anchors_kwargs)\n randomize_hz_kwargs = _copy_kwargs(randomize_hz_kwargs)\n\n if randomize_hz:\n finalize_fn = randomize.randomize_hz_finalize_fn_factory(finalize_fn)\n\n partial_build_gen = partial(\n _build_generator_lazy,\n setup_fn=setup_fn,\n finalize_fn=finalize_fn,\n add_X_imus=add_X_imus,\n add_X_imus_kwargs=add_X_imus_kwargs,\n add_X_jointaxes=add_X_jointaxes,\n add_X_jointaxes_kwargs=add_X_jointaxes_kwargs,\n add_y_relpose=add_y_relpose,\n add_y_rootincl=add_y_rootincl,\n randomize_positions=randomize_positions,\n randomize_motion_artifacts=randomize_motion_artifacts,\n randomize_joint_params=randomize_joint_params,\n imu_motion_artifacts=imu_motion_artifacts,\n imu_motion_artifacts_kwargs=imu_motion_artifacts_kwargs,\n dynamic_simulation=dynamic_simulation,\n dynamic_simulation_kwargs=dynamic_simulation_kwargs,\n output_transform=output_transform,\n keep_output_extras=keep_output_extras,\n use_link_number_in_Xy=use_link_number_in_Xy,\n )\n\n sys, config = utils.to_list(sys), utils.to_list(config)\n\n if randomize_anchors:\n assert (\n len(sys) == 1\n ), \"If `randomize_anchors`, then only one system is expected\"\n sys = randomize.randomize_anchors(sys[0], **randomize_anchors_kwargs)\n\n zip_sys_config = False\n if randomize_hz:\n zip_sys_config = True\n sys, config = randomize.randomize_hz(sys, config, **randomize_hz_kwargs)\n\n if sys_ml is None:\n # TODO\n if False and len(sys) > 1:\n warnings.warn(\n \"Batched simulation with multiple systems but no explicit `sys_ml`\"\n )\n sys_ml = sys[0]\n\n self.gens = []\n if zip_sys_config:\n for _sys, _config in zip(sys, config):\n self.gens.append(\n partial_build_gen(sys=_sys, config=_config, sys_ml=sys_ml)\n )\n else:\n for _sys in sys:\n for _config in config:\n self.gens.append(\n partial_build_gen(sys=_sys, config=_config, sys_ml=sys_ml)\n )\n\n def _to_data(self, sizes, seed, jit):\n return batch.batch_generators_eager_to_list(\n self.gens, sizes, seed=seed, jit=jit\n )\n\n def to_list(self, sizes: int | list[int] = 1, seed: int = 1, jit: bool = False):\n return self._to_data(sizes, seed, jit)\n\n def to_pickle(\n self,\n path: str,\n sizes: int | list[int] = 1,\n seed: int = 1,\n jit: bool = False,\n overwrite: bool = True,\n ) -> None:\n data = tree_utils.tree_batch(self._to_data(sizes, seed, jit))\n utils.pickle_save(data, path, overwrite=overwrite)\n\n def to_hdf5(\n self,\n path: str,\n sizes: int | list[int] = 1,\n seed: int = 1,\n jit: bool = False,\n overwrite: bool = True,\n ) -> None:\n data = tree_utils.tree_batch(self._to_data(sizes, seed, jit))\n utils.hdf5_save(path, data, overwrite=overwrite)\n\n def to_eager_gen(\n self,\n batchsize: int = 1,\n sizes: int | list[int] = 1,\n seed: int = 1,\n jit: bool = False,\n ) -> types.BatchedGenerator:\n return batch.batch_generators_eager(\n self.gens, sizes, batchsize, seed=seed, jit=jit\n )\n\n def to_lazy_gen(\n self, sizes: int | list[int] = 1, jit: bool = True\n ) -> types.BatchedGenerator:\n return batch.batch_generators_lazy(self.gens, sizes, jit=jit)\n\n @staticmethod\n def eager_gen_from_paths(\n paths: str | list[str],\n batchsize: int,\n include_samples: Optional[list[int]] = None,\n shuffle: bool = True,\n load_all_into_memory: bool = False,\n tree_transform=None,\n ) -> tuple[types.BatchedGenerator, int]:\n paths = utils.to_list(paths)\n return batch.batched_generator_from_paths(\n paths,\n batchsize,\n include_samples,\n shuffle,\n load_all_into_memory=load_all_into_memory,\n tree_transform=tree_transform,\n )\n
"},{"location":"api/#ring.algorithms.jcalc.MotionConfig","title":"MotionConfig
dataclass
","text":"Source code in src/ring/algorithms/jcalc.py
@dataclass\nclass MotionConfig:\n T: float = 60.0 # length of random motion\n t_min: float = 0.05 # min time between two generated angles\n t_max: float | TimeDependentFloat = 0.30 # max time ..\n\n dang_min: float | TimeDependentFloat = 0.1 # minimum angular velocity in rad/s\n dang_max: float | TimeDependentFloat = 3.0 # maximum angular velocity in rad/s\n\n # minimum angular velocity of euler angles used for `free and spherical joints`\n dang_min_free_spherical: float | TimeDependentFloat = 0.1\n dang_max_free_spherical: float | TimeDependentFloat = 3.0\n\n # max min allowed actual delta values in radians\n delta_ang_min: float | TimeDependentFloat = 0.0\n delta_ang_max: float | TimeDependentFloat = 2 * jnp.pi\n delta_ang_min_free_spherical: float | TimeDependentFloat = 0.0\n delta_ang_max_free_spherical: float | TimeDependentFloat = 2 * jnp.pi\n\n dpos_min: float | TimeDependentFloat = 0.001 # speed of translation\n dpos_max: float | TimeDependentFloat = 0.7\n pos_min: float | TimeDependentFloat = -2.5\n pos_max: float | TimeDependentFloat = +2.5\n\n # used by both `random_angle_*` and `random_pos_*`\n # only used if `randomized_interpolation` is set\n cdf_bins_min: int = 5\n # by default equal to `cdf_bins_min`\n cdf_bins_max: Optional[int] = None\n\n # flags\n randomized_interpolation_angle: bool = False\n randomized_interpolation_position: bool = False\n interpolation_method: str = \"cosine\"\n range_of_motion_hinge: bool = True\n range_of_motion_hinge_method: str = \"uniform\"\n\n # initial value of joints\n ang0_min: float = -jnp.pi\n ang0_max: float = jnp.pi\n pos0_min: float = 0.0\n pos0_max: float = 0.0\n\n # cor (center of rotation) custom fields\n cor: bool = False\n cor_t_min: float = 0.2\n cor_t_max: float | TimeDependentFloat = 2.0\n cor_dpos_min: float | TimeDependentFloat = 0.00001\n cor_dpos_max: float | TimeDependentFloat = 0.5\n cor_pos_min: float | TimeDependentFloat = -0.4\n cor_pos_max: float | TimeDependentFloat = 0.4\n\n def is_feasible(self) -> bool:\n return _is_feasible_config1(self)\n\n def to_nomotion_config(self) -> \"MotionConfig\":\n kwargs = asdict(self)\n for key in [\n \"dang_min\",\n \"dang_max\",\n \"delta_ang_min\",\n \"dang_min_free_spherical\",\n \"dang_max_free_spherical\",\n \"delta_ang_min_free_spherical\",\n \"dpos_min\",\n \"dpos_max\",\n ]:\n kwargs[key] = 0.0\n nomotion_config = MotionConfig(**kwargs)\n assert nomotion_config.is_feasible()\n return nomotion_config\n
"},{"location":"api/#ring.algorithms.jcalc.register_new_joint_type","title":"register_new_joint_type(joint_type, joint_model, q_width, qd_width=None, overwrite=False)
","text":"Source code in src/ring/algorithms/jcalc.py
def register_new_joint_type(\n joint_type: str,\n joint_model: JointModel,\n q_width: int,\n qd_width: Optional[int] = None,\n overwrite: bool = False,\n):\n # this name is used\n assert joint_type != \"default\", \"Please use another name.\"\n\n exists = joint_type in _joint_types\n if exists and overwrite:\n for dic in [\n base.Q_WIDTHS,\n base.QD_WIDTHS,\n _joint_types,\n ]:\n dic.pop(joint_type)\n else:\n assert (\n not exists\n ), f\"joint type `{joint_type}`already exists, use `overwrite=True`\"\n\n if qd_width is None:\n qd_width = q_width\n\n assert len(joint_model.motion) == qd_width\n\n _joint_types.update({joint_type: joint_model})\n base.Q_WIDTHS.update({joint_type: q_width})\n base.QD_WIDTHS.update({joint_type: qd_width})\n
"},{"location":"api/#ring.algorithms.jcalc.JointModel","title":"JointModel
dataclass
","text":"Source code in src/ring/algorithms/jcalc.py
@dataclass\nclass JointModel:\n # (q, params) -> Transform\n transform: Callable[[jax.Array, jax.Array], base.Transform]\n # len(motion) == len(qd)\n # if callable: joint_params -> base.Motion\n motion: list[base.Motion | Callable[[jax.Array], base.Motion]] = field(\n default_factory=lambda: []\n )\n # (config, key_t, key_value, params) -> jax.Array\n rcmg_draw_fn: Optional[DRAW_FN] = None\n\n # only used by `pd_control`\n p_control_term: Optional[P_CONTROL_TERM] = None\n qd_from_q: Optional[QD_FROM_Q] = None\n\n # used by\n # -`inverse_kinematics_endeffector`\n # - System.coordinate_vector_to_q\n coordinate_vector_to_q: Optional[COORDINATE_VECTOR_TO_Q] = None\n\n # only used by `inverse_kinematics`\n inv_kin: Optional[INV_KIN] = None\n\n init_joint_params: Optional[INIT_JOINT_PARAMS] = None\n\n utilities: Optional[dict[str, Any]] = field(default_factory=lambda: dict())\n
"},{"location":"api/#ring.algorithms.jcalc.join_motionconfigs","title":"join_motionconfigs(configs, boundaries)
","text":"Source code in src/ring/algorithms/jcalc.py
def join_motionconfigs(\n configs: list[MotionConfig], boundaries: list[float]\n) -> MotionConfig:\n assert len(configs) == (\n len(boundaries) + 1\n ), \"length of `boundaries` should be one less than length of `configs`\"\n boundaries = jnp.array(boundaries, dtype=float)\n\n def new_value(field: str):\n scalar_options = jnp.array([getattr(c, field) for c in configs])\n\n def scalar(t):\n return jax.lax.dynamic_index_in_dim(\n scalar_options, _find_interval(t, boundaries), keepdims=False\n )\n\n return scalar\n\n hints = get_type_hints(MotionConfig())\n attrs = MotionConfig().__dict__\n is_time_dependent_field = lambda key: hints[key] == (float | TimeDependentFloat)\n time_dependent_fields = [key for key in attrs if is_time_dependent_field(key)]\n time_independent_fields = [key for key in attrs if not is_time_dependent_field(key)]\n\n for time_dep_field in time_independent_fields:\n field_values = set([getattr(config, time_dep_field) for config in configs])\n assert (\n len(field_values) == 1\n ), f\"MotionConfig.{time_dep_field}={field_values}. Should be one unique value..\"\n\n changes = {field: new_value(field) for field in time_dependent_fields}\n return replace(configs[0], **changes)\n
"},{"location":"notebooks/batched_simulation/","title":"Batched simulation","text":"Note
This example is available as a jupyter notebook here.
System
object is a registered Jax-PyTree. This means it's a nested array.
This enables us to stack multiple systems (or states) to enable vectorized operations.
import ring\n\nimport jax\nimport jax.numpy as jnp\n\n\nxml_str = \"\"\"\n<x_xy model=\"double_pendulum\">\n<options dt=\"0.01\" gravity=\"0 0 9.81\"></options>\n<worldbody>\n<body damping=\"2\" euler=\"0 90 0\" joint=\"ry\" name=\"upper\">\n<geom dim=\"1 0.25 0.2\" mass=\"10\" pos=\"0.5 0 0\" type=\"box\"></geom>\n<body damping=\"2\" joint=\"ry\" name=\"lower\" pos=\"1 0 0\">\n<geom dim=\"1 0.25 0.2\" mass=\"10\" pos=\"0.5 0 0\" type=\"box\"></geom>\n</body>\n</body>\n</worldbody>\n</x_xy>\n\"\"\"\n\nsys = ring.System.create(xml_str)\nstate = ring.State.create(sys)\n
# second system with gravity disabled\nsys_nograv = sys.replace(gravity = sys.gravity * 0.0)\nsys_batched = sys.batch(sys_nograv)\n\nnext_state_batched = jax.vmap(ring.step, in_axes=(0, None))(sys_batched, state)\n
# note how the state of the system without gravity has not changed at all\nnext_state_batched.q\n
\nArray([[-1.7982468e-10, 2.3305433e-10],\n [ 0.0000000e+00, 0.0000000e+00]], dtype=float32)
\n
second_state = ring.State.create(sys, qd=jnp.ones((2,)))\nstate_batched = state.batch(second_state)\nnext_state_batched = jax.vmap(ring.step, in_axes=(None, 0))(sys, state_batched)\n
next_state_batched.q\n
\nArray([[-1.7982468e-10, 2.3305433e-10],\n [ 1.0048340e-02, 9.8215193e-03]], dtype=float32)
\n
Batched kinematic simulation is done by providing the sizes
argument to build_generator
batchsize = 8\nseed = 1\ngen = ring.RCMG(sys, ring.MotionConfig(T=10.0, t_max=1.5), keep_output_extras=True).to_lazy_gen(batchsize)\n(X, y), (_, q, x, _) = gen(jax.random.PRNGKey(seed))\n
q.shape\n
\n(8, 1000, 2)
\n
\n
"},{"location":"notebooks/batched_simulation/#batched-dynamical-simulation","title":"Batched Dynamical Simulation","text":""},{"location":"notebooks/batched_simulation/#batched-system","title":"Batched System","text":"I.e. simulating two different system with the same initial state.
"},{"location":"notebooks/batched_simulation/#batched-state","title":"Batched State","text":""},{"location":"notebooks/batched_simulation/#batched-kinematic-simulation","title":"Batched Kinematic Simulation","text":""},{"location":"notebooks/control/","title":"Control","text":"Note
This example is available as a jupyter notebook here.
import ring\n\nfrom ring.algorithms.generator.pd_control import _pd_control\n\nimport jax\nimport jax.numpy as jnp\nimport numpy as np\n\nimport mediapy as media\n
The step
function also takes generalized forces tau
applied to the degrees of freedom its third input step(sys, state, taus)
.
Let's consider an inverted pendulum on a cart, and apply a left-right force onto the cart such that the pole stays in the upright position.
xml_str = \"\"\"\n<x_xy model=\"inv_pendulum\">\n<options dt=\"0.01\" gravity=\"0 0 9.81\"></options>\n<defaults>\n<geom color=\"white\" edge_color=\"black\"></geom>\n</defaults>\n<worldbody>\n<body damping=\"0.01\" joint=\"px\" name=\"cart\">\n<geom dim=\"0.4 0.1 0.1\" mass=\"1\" type=\"box\"></geom>\n<body damping=\"0.01\" euler=\"0 -90 0\" joint=\"ry\" name=\"pendulum\">\n<geom dim=\"1 0.1 0.1\" mass=\"0.5\" pos=\"0.5 0 0\" type=\"box\"></geom>\n</body>\n</body>\n</worldbody>\n</x_xy>\n\"\"\"\n\nsys = ring.System.create(xml_str)\nstate = ring.State.create(sys, q=jnp.array([0.0, 0.2])) \n\nxs = []\nT = 10.0\nfor t in range(int(T / sys.dt)):\n measurement_noise = np.random.normal() * 5\n phi = jnp.rad2deg(state.q[1]) + measurement_noise\n cart_motor_input = 0.1 * phi * abs(phi)\n taus = jnp.clip(jnp.array([cart_motor_input, 0.0]), -10, 10) \n state = jax.jit(ring.step)(sys, state, taus)\n xs.append(state.x)\n
def show_video(sys, xs: list[ring.Transform]):\n assert sys.dt == 0.01\n # only render every fourth to get a framerate of 25 fps\n frames = sys.render(xs, render_every_nth=4, camera=\"c\", add_cameras={-1: '<camera mode=\"targetbody\" name=\"c\" pos=\"0 -2 2\" target=\"0\"></camera>'})\n media.show_video(frames, fps=25)\n\nshow_video(sys, xs)\n
\nRendering frames..: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 250/250 [00:01<00:00, 174.21it/s]\n
\n
This browser does not support the video tag. xml_str = \"\"\"\n<x_xy>\n<options dt=\"0.01\" gravity=\"0 0 9.81\"></options>\n<worldbody>\n<body damping=\"0.01\" euler=\"0 90 0\" joint=\"ry\" name=\"pendulum\" pos=\"0 0 1\">\n<geom dim=\"1 0.1 0.1\" mass=\"0.5\" pos=\"0.5 0 0\" type=\"box\"></geom>\n</body>\n</worldbody>\n</x_xy>\n\"\"\"\n\nsys = ring.System.create(xml_str)\nP, D = jnp.array([10.0]), jnp.array([1.0])\n\ndef simulate_pd_control(sys, P, D):\n controller = _pd_control(P, D)\n # reference signal\n q_ref = jnp.ones((1000, 1)) * jnp.pi / 2\n controller_state = controller.init(sys, q_ref)\n state = ring.State.create(sys) \n\n xs = []\n T = 5.0\n for t in range(int(T / sys.dt)):\n controller_state, taus = jax.jit(controller.apply)(controller_state, sys, state)\n state = jax.jit(ring.step)(sys, state, taus)\n xs.append(state.x)\n return xs\n
xs = simulate_pd_control(sys, P, D)\nshow_video(sys, xs)\n
\nRendering frames..: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 125/125 [00:00<00:00, 165.26it/s]\n
\n
This browser does not support the video tag. Note the steady state error. This is because we have gravity and no Integral part (so no PID control).
If we remove gravity the steady state error also vanishes (as is expected.)
sys_nograv = sys.replace(gravity = sys.gravity * 0.0)\nxs = simulate_pd_control(sys_nograv, P, D)\nshow_video(sys_nograv, xs)\n
\nRendering frames..: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 125/125 [00:00<00:00, 132.06it/s]\n
\n
This browser does not support the video tag. \n
"},{"location":"notebooks/control/#balance-an-inverted-pendulum-on-a-cart","title":"Balance an inverted Pendulum on a cart","text":""},{"location":"notebooks/control/#pd-control","title":"PD Control","text":""},{"location":"notebooks/custom_joint_type/","title":"Custom joint type","text":"Note
This example is available as a jupyter notebook here.
In this notebook we will define a new joint type that is a hinge joint with a random joint axes direction.
It will also support dynamical simulation.
import ring\nfrom ring import maths, base\n\nimport jax\nimport jax.numpy as jnp\n\nimport mediapy as media\n\nfrom ring.algorithms.jcalc import _draw_rxyz\n
We will give this new joint type the identifier rr
(random revolute). Although it actually already exists in the library, but we can overwrite it.
# we use such a `params` input to specify the joint-axes, if we later then randomize the attribute of the system object\n# we will have the effect of a hinge joint with a randomized joint axes direction\n# here we tell the library how it should initialize this `params` PyTree\ndef _draw_random_joint_axis(key):\n return maths.rotate(jnp.array([1.0, 0, 0]), maths.quat_random(key))\n\ndef _rr_init_joint_params(key):\n return dict(joint_axes=_draw_random_joint_axis(key))\n\n# next, we tell the library how it can randomly draw a trajectory for its generalized coordinate; the hinge joint angle\ndef _rr_transform(q, params):\n # here we use this `params` object\n axis = params[\"joint_axes\"]\n q = jnp.squeeze(q)\n rot = maths.quat_rot_axis(axis, q)\n return ring.Transform.create(rot=rot)\n\n# this tells the library how to dynamically simulate the type of joint\ndef _motion_fn(params):\n return base.Motion.create(ang=params[\"joint_axes\"])\n\n# now, we can put it all together into a new `x_xy.JointModel`\nrr_joint = ring.JointModel(_rr_transform, motion=[_motion_fn], rcmg_draw_fn=_draw_rxyz, init_joint_params=_rr_init_joint_params)\n\n# and then we register the joint; Note that `overwrite`=True, because it already exists; that way you can e.g. overwrite the\n# default joint types such as the free joint\nring.register_new_joint_type(\"rr\", rr_joint, q_width=1, qd_width=1, overwrite=True)\n
xml_str = \"\"\"\n<x_xy>\n<options dt=\"0.01\" gravity=\"0 0 9.81\"></options>\n<worldbody>\n<geom dim=\"0.1\" type=\"xyz\"></geom>\n<body damping=\".01\" joint=\"rr\" name=\"pendulum\" pos=\"0 0 0.5\">\n<geom dim=\"0.1\" type=\"xyz\"></geom>\n<geom dim=\"0.5 0.1 0.1\" mass=\"0.5\" pos=\"0.25 0 0\" type=\"box\"></geom>\n</body>\n</worldbody>\n</x_xy>\n\"\"\"\n\n# this seed determines (among other things) the randomness of the joint-axes direction\n# via the above specified `_rr_init_joint_params`\nseed: int = 2\nsys = ring.System.create(xml_str, seed=seed)\n
state = ring.State.create(sys)\nxs = []\nfor t in range(500):\n state = jax.jit(ring.step)(sys, state)\n xs.append(state.x)\n
sys.links.joint_params\n
\n{'rr': {'joint_axes': Array([[ 0.41278404, -0.6329913 , 0.65492845]], dtype=float32)},\n 'default': Array([], shape=(1, 0), dtype=float32)}
\n
def show_video(sys, xs: list[ring.Transform]):\n frames = sys.render(xs, render_every_nth=4)\n media.show_video(frames, fps=25)\n\nshow_video(sys, xs)\n
\nRendering frames..: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 125/125 [00:00<00:00, 158.41it/s]\n
\n
This browser does not support the video tag. the class x_xy.RCMG
already has the built-in flag randomize_joint_params
which can be toggled in order to use the user-provided logic _rr_init_joint_params
for randomizing the joint parameters
(X, y), (key, q, x, _) = ring.RCMG(sys, randomize_joint_params=True, keep_output_extras=True).to_list()[0]\n
\neager data generation: 1it [00:02, 2.24s/it]\n
\n
show_video(sys, x)\n
\nRendering frames..: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 1500/1500 [00:08<00:00, 169.89it/s]\n
\n
This browser does not support the video tag. but for dynamic_simulation
flag to work we additional need to specify the function ring.JointModel.p_control_term
print(rr_joint.p_control_term)\n
\nNone\n
\n
try:\n (X, y), (key, q, x, _) = ring.RCMG(sys, randomize_joint_params=True, keep_output_extras=True, dynamic_simulation=True).to_list()[0]\nexcept NotImplementedError:\n print(\"NotImplementedError: Please specify `JointModel.p_control_term` for joint type `rr`\")\n
\neager data generation: 0it [00:00, ?it/s]
\n
\nNotImplementedError: Please specify `JointModel.p_control_term` for joint type `rr`\n
\n
\n\n
\n
\n
"},{"location":"notebooks/custom_joint_type/#defining-a-custom-joint-type-that-supports-dynamical-simulation","title":"Defining a custom Joint Type that supports dynamical simulation","text":""},{"location":"notebooks/error_quaternion/","title":"Error quaternion","text":"Note
This example is available as a jupyter notebook here.
In this notebook we will talk about what functions you need to do ML with quaternions. After all the purpose of this library is to create training data.
Typically, this involves quaternions as target values (to be predicted), similar to an orientation estimation filter (like VQF).
So, suppose you want to train some ML model that predicts a quaternion \\(\\hat{q} = f_\\theta(X)\\).
import ring\nimport jax \nimport jax.numpy as jnp\nimport matplotlib.pyplot as plt\n
# suppose a 6D IMU input\nfeature_dim = 6\n\nparams = jax.random.normal(jax.random.PRNGKey(1), (4, feature_dim))\ndef neural_network(params, X):\n q_unnormalized = params@X\n norm = jnp.linalg.norm(q_unnormalized)\n return q_unnormalized / norm\n\n\ndef loss_fn(params, X, y):\n q, qhat = y, neural_network(params, X)\n # squared angle error\n return ring.maths.angle_error(q, qhat)**2\n
But this is dangerous as this might lead to NaNs.
X = jnp.zeros((6,))\ny = jnp.array([1.0, 0, 0, 0])\nloss_fn(params, X, y)\n
\nArray(nan, dtype=float32)
\n
We could try to fix is by adding a small number in the divison.
# suppose a 6D IMU input\nfeature_dim = 6\n\nparams = jax.random.normal(jax.random.PRNGKey(1), (4, feature_dim))\ndef neural_network(params, X):\n q_unnormalized = params@X\n norm = jnp.linalg.norm(q_unnormalized)\n eps = 1e-8\n return q_unnormalized / (norm + eps)\n\n\ndef loss_fn(params, X, y):\n q, qhat = y, neural_network(params, X)\n # squared angle error\n return ring.maths.angle_error(q, qhat)**2\n\nX = jnp.zeros((6,))\ny = jnp.array([1.0, 0, 0, 0])\nloss_fn(params, X, y)\n
\nArray(0., dtype=float32)
\n
But, still the gradient required for backpropagation gives NaNs.
jax.grad(loss_fn)(params, X, y)\n
\nArray([[nan, nan, nan, nan, nan, nan],\n [nan, nan, nan, nan, nan, nan],\n [nan, nan, nan, nan, nan, nan],\n [nan, nan, nan, nan, nan, nan]], dtype=float32)
\n
The solution is a little involved. TLDR; Use x_xy.maths.safe_normalize
# suppose a 6D IMU input\nfeature_dim = 6\n\nparams = jax.random.normal(jax.random.PRNGKey(1), (4, feature_dim))\ndef neural_network(params, X):\n q_unnormalized = params@X\n return ring.maths.safe_normalize(q_unnormalized)\n\n\ndef loss_fn(params, X, y):\n q, qhat = y, neural_network(params, X)\n # squared angle error\n return ring.maths.angle_error(q, qhat)**2\n\nX = jnp.zeros((6,))\ny = jnp.array([1.0, 0, 0, 0])\nloss_fn(params, X, y)\n
\nArray(0., dtype=float32)
\n
jax.grad(loss_fn)(params, X, y)\n
\nArray([[0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.]], dtype=float32)
\n
Let's take a closer look at the function x_xy.maths.angle_error
which was used in the loss_fn
in the above.
What is the behaviour of the error function (sort of the metric) between two quaternions as one approaches the other?
A first implementation might look like this:
def quat_error(q, qhat):\n q_error = ring.maths.quat_mul(ring.maths.quat_inv(q), qhat)\n phi = 2 * jnp.arccos(q_error[0])\n return jnp.abs(phi)\n
Let's reduce this function to the critical operation phi = ...
and let's assume, without loss of generality, that the target quaternion is the identity quaternion.
Then, this effectively becomes about extracting the angle from a quaternion safely.
def quat_angle(q):\n return 2 * jnp.arccos(q[0])\n
input_angles = jnp.linspace(-0.005, 0.005, num=1000)\n\ndef input_to_output_angles_incorrect(angle):\n q = ring.maths.quat_rot_axis(jnp.array([1.0, 0, 0]), angle)\n return quat_angle(q)\n\ndef input_to_output_angles_correct(angle):\n q = ring.maths.quat_rot_axis(jnp.array([1.0, 0, 0]), angle)\n return ring.maths.quat_angle(q)\n
plt.plot(input_angles, jax.vmap(input_to_output_angles_incorrect)(input_angles), label=\"incorrect\")\nplt.plot(input_angles, jax.vmap(input_to_output_angles_correct)(input_angles), label=\"correct\")\nplt.legend()\nplt.show()\n
As one might expect, the gradients are also much more stable.
plt.plot(input_angles, jax.vmap(jax.grad(input_to_output_angles_incorrect))(input_angles), label=\"incorrect\")\nplt.plot(input_angles, jax.vmap(jax.grad(input_to_output_angles_correct))(input_angles), label=\"correct\")\nplt.legend()\nplt.show()\n
\n
"},{"location":"notebooks/error_quaternion/#the-error-quaternion-required-for-ml-purposes","title":"The error quaternion (required for ML purposes)","text":""},{"location":"notebooks/error_quaternion/#how-to-get-a-quaternion-as-network-output","title":"How to get a quaternion as network output?","text":"That's easy enough. You normalize a four dimensional vector.
"},{"location":"notebooks/error_quaternion/#a-closer-look-at-the-function-x_xymathsangle_error","title":"A closer look at the functionx_xy.maths.angle_error
","text":""},{"location":"notebooks/error_quaternion/#pytorch-library-for-quaternion-operations","title":"Pytorch library for quaternion operations","text":"These functions are for JAX, but the following should work for PyTorch -> https://naver.github.io/roma/
"},{"location":"notebooks/experimental_data/","title":"Experimental data","text":"Note
This example is available as a jupyter notebook here.
import ring\n\nimport jax\nimport jax.numpy as jnp\nimport numpy as np\nimport matplotlib.pyplot as plt\n\nimport mediapy as media\n\ndef show_video(sys: ring.System, xs: ring.Transform) -> None:\n assert sys.dt == 0.01\n # only render every fourth to get a framerate of 25 fps\n frames = sys.render(xs, camera=\"c\", height=480, width=640, render_every_nth=4,\n add_cameras={-1: '<camera mode=\"targetbody\" name=\"c\" pos=\".5 -.5 1.25\" target=\"3\"></camera>'})\n media.show_video(frames, fps=25)\n
Experimental data and system definitions of the experimental setup are located in..
from ring import exp\n
Multiple experimental trials are available. They have exp_id
s and motion_start
s and motion_stop
s
exp_id = \"S_06\"\nsys = exp.load_sys(exp_id)\n
Let's first take a look at the system that was used in the experiments.
state = ring.State.create(sys)\n# update the maximal coordinates\nxs = ring.algorithms.forward_kinematics(sys, state)[1].x\n
show_video(sys, xs)\n
\nRendering frames..: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 1/1 [00:00<00:00, 7.88it/s]\n
\n
This browser does not support the video tag. As you can see a five segment kinematic chain was moved, and for each segment IMU measurements and OMC ground truth is available.
Let's load this (no simulated) IMU and OMC data.
# `canonical` is the identifier of the first motion pattern performed in this trial\n# `shaking` is the identifier of the last motion pattern performed in this trial\nmotion_start = \"canonical\"\ndata = exp.load_data(exp_id, motion_start=motion_start)\n
data.keys()\n
\ndict_keys(['seg1', 'seg2', 'seg3', 'seg4', 'seg5'])
\n
data[\"seg1\"].keys()\n
\ndict_keys(['imu_flex', 'imu_rigid', 'marker1', 'marker2', 'marker3', 'marker4', 'quat'])
\n
data[\"seg1\"][\"imu_rigid\"].keys()\n
\ndict_keys(['acc', 'gyr', 'mag'])
\n
The quaternion quat
is to be interpreted as the rotation from segment to an arbitrary OMC inertial frame.
The position marker1
is to be interpreted as the position vector from arbitrary OMC inertial frame to a specific marker (marker 1) on the respective segment (vector given in the OMC inertial frame).
Then, for each segment actually two IMUs are attached to it. One is rigidly attached, one is non-rigidly attached (via foam).
Also, how long is the trial?
data[\"seg1\"][\"marker1\"].shape\n
\n(14200, 3)
\n
It's 325 seconds of data.
Let's take a look at the motion of the whole trial.
To render it, we need maximal coordinates xs
of all links in the system.
X, y, xs, xs_noimu = exp.benchmark_fn(exp.IMTP(segments=sys.findall_segments()), exp_id, motion_start)\n
show_video(sys, xs)\n
\nRendering frames..: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 3550/3550 [00:30<00:00, 118.12it/s]\n
\n
This browser does not support the video tag. Perfect. This is a rendered animation of the real experimental motion that was performed. You can see that the spacing between segments is not perfect.
This is because in our idealistic system model joints have no spatial dimension but in reality they have. The entire setup is 3D printed, and the joints are also several centimeters long.
The segments are 20cm long.
We can use this experimental data to validate our simulated approaches or validate ML models that are learned on simulated training data.
\n
"},{"location":"notebooks/experimental_data/#loading-and-working-with-experimental-data","title":"Loading and working with experimental data","text":""},{"location":"notebooks/getting_started/","title":"Getting started","text":"Note
This example is available as a jupyter notebook here.
And on Google Colab here
Setup the environment if this is executed on Google Colab.
Make sure to change the runtime type to GPU
. To do this go to Runtime
-> Change runtime type
-> GPU
Otherwise, rendering won't work in Google Colab.
import os\n\ntry:\n import google.colab\n IN_COLAB = True\nexcept:\n IN_COLAB = False\n\nif IN_COLAB:\n os.system(\"pip install --quiet 'ring @ git+https://github.com/SimiPixel/ring'\")\n os.system(\"pip install --quiet mediapy\")\n os.system(\"pip install --quiet matplotlib\")\n
import ring\n# automatically detects colab or not\nring.utils.setup_colab_env()\n\nimport jax\nimport jax.numpy as jnp\nimport numpy as np\nimport matplotlib.pyplot as plt\n\nimport mediapy as media\n
Systems are defined with the following xml syntax.
xml_str = \"\"\"\n<x_xy model=\"double_pendulum\">\n<options dt=\"0.01\" gravity=\"0 0 9.81\"></options>\n<worldbody>\n<body damping=\"2\" euler=\"0 90 0\" joint=\"ry\" name=\"upper\" pos=\"0 0 2\">\n<geom dim=\"1 0.25 0.2\" mass=\"10\" pos=\"0.5 0 0\" type=\"box\"></geom>\n<body damping=\"2\" joint=\"ry\" name=\"lower\" pos=\"1 0 0\">\n<geom dim=\"1 0.25 0.2\" mass=\"10\" pos=\"0.5 0 0\" type=\"box\"></geom>\n</body>\n</body>\n</worldbody>\n</x_xy>\n\"\"\"\n
With this xml description of the system, we are ready to load the system using load_sys_from_str
. We can also save this to a text-file double_pendulum.xml
and load with load_sys_from_xml
.
sys = ring.System.create(xml_str)\n
sys.model_name\n
\n'double_pendulum'
\n
System objects have many attributes. You may refer to the API documentation for more details.
sys.link_names\n
\n['upper', 'lower']
\n
Let's start with the most obvious. A physical simulation. We refer to it as \"dynamical simulation\", in contrast to what we do a little later which is a purely kinematic simulation.
First, we have to create the dynamical state of the system. It is defined by the all degrees of freedom in the system and their velocities. Here, we have two revolute joints (one degree of freedom). Thus, the minimal coordinates vector \\(q\\) and minimal velocity vector \\(q'\\) has two dimensions.
state = ring.State.create(sys)\n
state.q\n
\nArray([0., 0.], dtype=float32)
\n
state.qd\n
\nArray([0., 0.], dtype=float32)
\n
next_state = ring.step(sys, state)\n
Massive speedups if we use jax.jit
to jit-compile the function.
%timeit ring.step(sys, state)\n
\n196 ms \u00b1 5.62 ms per loop (mean \u00b1 std. dev. of 7 runs, 1 loop each)\n
\n
%timeit jax.jit(ring.step)(sys, state)\n
\n90.2 \u00b5s \u00b1 41.4 \u00b5s per loop (mean \u00b1 std. dev. of 7 runs, 1 loop each)\n
\n
Let's unroll the dynamics for multiple timesteps.
T = 10.0\nxs = []\nfor _ in range(int(T / sys.dt)):\n state = jax.jit(ring.step)(sys, state)\n xs.append(state.x)\n
Next, let's render the frames and create an animation.
frames = sys.render(xs, camera=\"targetfar\")\n
\nRendering frames..: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 1000/1000 [00:05<00:00, 176.53it/s]\n
\n
def show_video(frames: list[np.ndarray], dt: float):\n assert dt == 0.01\n # frames are at 100 Hz, but let's create an animation at 25Hz\n media.show_video([frames[i][..., :3] for i in range(0, len(frames), 4)], fps=25)\n\nshow_video(frames, sys.dt)\n
This browser does not support the video tag. Hmm, pretty boring. Let's get the pendulum into an configuration with some potential energy.
All we have to change is the initial state state.q
.
state = ring.State.create(sys, q=jnp.array([jnp.pi / 2, 0]))\n
T = 10.0\nxs = []\nfor _ in range(int(T / sys.dt)):\n state = jax.jit(ring.step)(sys, state)\n xs.append(state.x)\n\nframes = sys.render(xs, camera=\"targetfar\")\nshow_video(frames, sys.dt)\n
\nRendering frames..: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 1000/1000 [00:05<00:00, 177.75it/s]\n
\n
This browser does not support the video tag. That's more like it!
Next, we will take a look at \"kinematic simulation\".
Let's start with why you would want this.
Imagine we want to learn a filter that estimates some quantity of interest from some sensor input.
Then, we could try to create many random motions, record the measured sensor input, and the ground truth quantity of interest target values.
This is then used as training data for a Machine Learning model.
The general interface to kinematic simulation is via x_xy.RCMG
.
This class can then create - a function (of type Generator
) that maps a PRNG seed to, e.g., X, y
data. - a list of data - data on disk (saved via pickle or hdf5)
(X, y), (key, q, xs, _) = ring.RCMG(sys, ring.MotionConfig(T=10.0, t_max=1.5), keep_output_extras=True).to_list()[0]\n
\neager data generation: 1it [00:01, 1.95s/it]\n
\n
frames = sys.render(xs, camera=\"targetfar\")\n
\nRendering frames..: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 1000/1000 [00:05<00:00, 177.14it/s]\n
\n
This is now completely random, but unphysical motion. It's only kinematics, but that is okay for creating training data.
show_video(frames, sys.dt)\n
This browser does not support the video tag. We are interested in simulating IMU data as input X
, and estimating quaternions as target y
.
We can easily simulate an IMU with only the trajectory of maximal coordinates xs
.
Suppose, we want to simulate an IMU right that is placed on the lower
segment and right at the revolute joint.
This is exactly where the coordinate system of the lower
segment is placed.
Right now the xs
trajectory contains both coordinate sytems of upper
and lower
.
# (n_timesteps, n_links, 3)\nxs.pos.shape\n
\n(1000, 2, 3)
\n
# (n_timesteps, n_links, 4)\nxs.rot.shape\n
\n(1000, 2, 4)
\n
From the axis with length two, the 0th entry is for upper
and the 1st entry is for lower
.
sys.name_to_idx(\"upper\")\n
\n0
\n
sys.name_to_idx(\"lower\")\n
\n1
\n
xs_lower = xs.take(1, axis=1)\n
imu_lower = ring.algorithms.imu(xs_lower, sys.gravity, sys.dt)\n
imu_lower.keys()\n
\ndict_keys(['acc', 'gyr'])
\n
plt.grid()\nplt.plot(np.arange(0, 10.0, step=sys.dt), imu_lower[\"gyr\"], label=[\"x\", \"y\", \"z\"])\nplt.ylabel(\"gyro [rad / s]\")\nplt.xlabel(\"time [s]\")\nplt.legend()\nplt.show()\n
As you can see it's a two-dimensional problem, which is why only one (y
) is non-zero.
Let's consider a larger kinematic chain in free 3D space.
xml_str = \"\"\"\n<x_xy model=\"three_segment_kinematic_chain\">\n<options dt=\"0.01\" gravity=\"0 0 9.81\"></options>\n<worldbody>\n<body joint=\"free\" name=\"seg2\" pos=\"0 0 2\">\n<geom dim=\"1 0.25 0.2\" mass=\"0.1\" pos=\"0.5 0 0\" type=\"box\"></geom>\n<body joint=\"ry\" name=\"seg1\">\n<geom dim=\"1 0.25 0.2\" mass=\"0.1\" pos=\"-0.5 0 0\" type=\"box\"></geom>\n<body joint=\"frozen\" name=\"imu1\" pos=\"-0.5 0 0.125\">\n<geom color=\"orange\" dim=\"0.2 0.2 0.05\" mass=\"0.05\" type=\"box\"></geom>\n</body>\n</body>\n<body joint=\"rz\" name=\"seg3\" pos=\"1 0 0\">\n<geom dim=\"1 0.25 0.2\" mass=\"0.1\" pos=\"0.5 0 0\" type=\"box\"></geom>\n<body joint=\"frozen\" name=\"imu2\" pos=\"0.5 0 -0.125\">\n<geom color=\"orange\" dim=\"0.2 0.2 0.05\" mass=\"0.05\" type=\"box\"></geom>\n</body>\n</body>\n</body>\n</worldbody>\n</x_xy>\n\"\"\"\n\nsys = ring.System.create(xml_str)\ndata = ring.RCMG(sys, ring.MotionConfig(T=10.0, t_max=1.5), add_X_imus=True,\n add_y_relpose=True, keep_output_extras=True).to_list()\n\n# with `keep_output_extras` really everything one could possibly imagine is returned\n(X, y), (key, qs, xs, sys_mod) = data[0]\n\nframes = sys.render(xs, camera=\"targetfar\")\nshow_video(frames, sys.dt)\n
\neager data generation: 1it [00:05, 5.23s/it]\nRendering frames..: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 1000/1000 [00:05<00:00, 182.81it/s]\n
\n
This browser does not support the video tag. The two orange boxes on segment 1 and segment 3 are modelling our two IMUs. This will be the network's input X
data.
As target we will try to estimate both relative orientations as y
data.
X.keys()\n
\ndict_keys(['seg1', 'seg2', 'seg3'])
\n
X[\"seg1\"].keys()\n
\ndict_keys(['acc', 'gyr'])
\n
y.keys()\n
\ndict_keys(['seg1', 'seg3'])
\n
plt.grid()\nplt.plot(np.arange(0, 10.0, step=sys.dt), X[\"seg1\"][\"gyr\"], label=[\"x\", \"y\", \"z\"])\nplt.ylabel(\"gyro [rad / s]\")\nplt.xlabel(\"time [s]\")\nplt.title(\"IMU 1 Gyroscope\")\nplt.legend()\nplt.show()\n
Now, the IMU is non-zero in all three x/y/z
components.
plt.grid()\nplt.plot(np.arange(0, 10.0, step=sys.dt), y[\"seg1\"], label=[\"w\", \"x\", \"y\", \"z\"])\nplt.xlabel(\"time [s]\")\nplt.title(\"Relative quaternion from seg2 to seg1\")\nplt.legend()\nplt.show()\n
Note how the relative quaternion is only around the y-axis. Can you see why? (Hint: Check the defining xml_str
.)
\n
"},{"location":"notebooks/getting_started/#dynamical-simulation","title":"Dynamical Simulation","text":""},{"location":"notebooks/getting_started/#kinematic-simulation","title":"Kinematic Simulation","text":""},{"location":"notebooks/getting_started/#x-y-training-data-attaching-sensors","title":"X, y
Training data / Attaching sensors","text":""},{"location":"notebooks/imu_modeling/","title":"Imu modeling","text":"from x_xy.subpkgs import exp\nimport matplotlib.pyplot as plt\nimport jax\nimport x_xy\nimport jax.numpy as jnp\n\nhz = 100\nmarkerMap = {\n \"seg1\": 2,\n \"seg5\": 2,\n \"seg2\": 1,\n \"seg3\": 2,\n \"seg4\": 4\n}\n\ndef load_data(seg: str, t1: float, t2: float, motion: str = \"fast\"):\n\n data = exp.load_data(\"S_06\", motion)[seg]\n\n # extract a small window from long time series for plotting\n pos, rot, imu_data = jax.tree_map(lambda arr: arr[int(t1 * hz): int(t2 * hz)], \n (data[f\"marker{markerMap[seg]}\"], data[\"quat\"], data[\"imu_rigid\"]))\n rot = x_xy.maths.quat_inv(rot)\n\n # maximal coordinates of segment, there is (almost) no sensor-to-segment orientation\n xs = x_xy.Transform.create(pos, rot)\n return pos, rot, xs, imu_data\n\n\nt1, t2 = 3.0, 9.0\npos, rot, xs, imu_data = load_data(\"seg1\", t1, t2)\n
Remove gravity from accelerometer to better compare.
def linear_acceleration(xs: x_xy.Transform, acc: jax.Array) -> jax.Array:\n q_E2Imu = xs.rot\n q_Imu2E = x_xy.maths.quat_inv(q_E2Imu)\n gravity = jnp.array([0, 0, 9.81])\n acc_E_nograv = x_xy.maths.rotate(acc, q_Imu2E) - gravity\n return x_xy.maths.rotate(acc_E_nograv, q_E2Imu)\n\nimu_data[\"acc\"] = linear_acceleration(xs, imu_data[\"acc\"])\n
def plot_imu(imu_data: dict):\n imu_data = jax.tree_map(lambda arr: arr[:-100], imu_data.copy())\n fig, axes = plt.subplots(1, 2, figsize=(12, 4))\n ts = jnp.arange(t1, t2 - 1.0, step=1 / hz)\n axes[0].plot(ts, imu_data[\"acc\"], label=[\"x\", \"y\", \"z\"])\n axes[1].plot(ts, imu_data[\"gyr\"], label=[\"x\", \"y\", \"z\"])\n for ax in axes:\n ax.grid()\n ax.set_xlabel(\"time [s]\")\n ax.legend()\n axes[0].set_title(\"Acc\")\n axes[1].set_title(\"Gyr\")\n\nplot_imu(imu_data)\n
imu_data = x_xy.imu(xs, gravity=jnp.zeros((3,)), dt=1 / hz)\nplot_imu(imu_data)\n
Accelerometer doesn't look too great! We need low-pass filtering. Two options:
imu_data = x_xy.imu(xs, gravity=jnp.zeros((3,)), dt=1 / hz, quasi_physical=True)\nplot_imu(imu_data)\n
imu_data = x_xy.imu(xs, gravity=jnp.zeros((3,)), dt=1 / hz, low_pass_filter_pos_f_cutoff=15.0, low_pass_filter_rot_alpha=0.55)\nplot_imu(imu_data)\n
from scipy.optimize import minimize\n\ndef optimize_parameters(seg: str, motion: str):\n\n # include all `fast` data in the optimization\n t1, t2 = 0.0, 500.0\n pos, rot, xs, imu_data = load_data(seg, t1, t2, motion)\n imu_data[\"acc\"] = linear_acceleration(xs, imu_data[\"acc\"])\n\n @jax.jit\n def objective(params):\n f_cutoff, alpha, offset = params\n\n # probably move about 5cm negative x-axis in local CS for e.g. segment 1\n pos_offset = x_xy.maths.rotate(x_xy.maths.rotate(pos, rot) + jnp.array([offset, 0, 0]), x_xy.maths.quat_inv(rot))\n xs_offset = xs.replace(pos=pos_offset)\n imu = x_xy.imu(xs_offset, jnp.zeros((3,)), 1 / hz, low_pass_filter_pos_f_cutoff=f_cutoff, low_pass_filter_rot_alpha=alpha)\n\n return jnp.mean((imu_data[\"acc\"] - imu[\"acc\"])**2) + jnp.mean((imu_data[\"gyr\"] - imu[\"gyr\"])**2)\n\n return minimize(objective, jnp.array([5.0, 1.0, 0.0]), method=\"Nelder-Mead\")\n
for seg in [\"seg1\", \"seg2\", \"seg3\", \"seg4\", \"seg5\"]:\n print(optimize_parameters(seg, \"fast\"))\n
\n message: Optimization terminated successfully.\n success: True\n status: 0\n fun: 0.7932413816452026\n x: [ 1.135e+01 1.034e+01 1.147e-01]\n nit: 147\n nfev: 287\n final_simplex: (array([[ 1.135e+01, 1.034e+01, 1.147e-01],\n [ 1.135e+01, 1.034e+01, 1.147e-01],\n [ 1.135e+01, 1.034e+01, 1.147e-01],\n [ 1.135e+01, 1.034e+01, 1.147e-01]]), array([ 7.932e-01, 7.932e-01, 7.933e-01, 7.933e-01]))\n message: Optimization terminated successfully.\n success: True\n status: 0\n fun: 0.40395233035087585\n x: [ 1.123e+01 1.112e+01 1.159e-01]\n nit: 98\n nfev: 198\n final_simplex: (array([[ 1.123e+01, 1.112e+01, 1.159e-01],\n [ 1.123e+01, 1.112e+01, 1.159e-01],\n [ 1.123e+01, 1.112e+01, 1.159e-01],\n [ 1.123e+01, 1.112e+01, 1.159e-01]]), array([ 4.040e-01, 4.040e-01, 4.040e-01, 4.040e-01]))\n message: Optimization terminated successfully.\n success: True\n status: 0\n fun: 0.37816864252090454\n x: [ 1.190e+01 1.226e+01 1.195e-01]\n nit: 121\n nfev: 238\n final_simplex: (array([[ 1.190e+01, 1.226e+01, 1.195e-01],\n [ 1.190e+01, 1.226e+01, 1.195e-01],\n [ 1.190e+01, 1.226e+01, 1.195e-01],\n [ 1.190e+01, 1.226e+01, 1.195e-01]]), array([ 3.782e-01, 3.782e-01, 3.782e-01, 3.782e-01]))\n message: Optimization terminated successfully.\n success: True\n status: 0\n fun: 0.541861355304718\n x: [ 1.131e+01 1.372e+01 1.160e-01]\n nit: 173\n nfev: 330\n final_simplex: (array([[ 1.131e+01, 1.372e+01, 1.160e-01],\n [ 1.131e+01, 1.372e+01, 1.160e-01],\n [ 1.131e+01, 1.372e+01, 1.160e-01],\n [ 1.131e+01, 1.372e+01, 1.160e-01]]), array([ 5.419e-01, 5.419e-01, 5.419e-01, 5.419e-01]))\n message: Optimization terminated successfully.\n success: True\n status: 0\n fun: 0.6123160123825073\n x: [ 1.106e+01 9.883e+00 1.211e-01]\n nit: 102\n nfev: 202\n final_simplex: (array([[ 1.106e+01, 9.883e+00, 1.211e-01],\n [ 1.106e+01, 9.883e+00, 1.211e-01],\n [ 1.106e+01, 9.883e+00, 1.211e-01],\n [ 1.106e+01, 9.883e+00, 1.211e-01]]), array([ 6.123e-01, 6.123e-01, 6.123e-01, 6.123e-01]))\n
\n
for seg in [\"seg1\", \"seg2\", \"seg3\", \"seg4\", \"seg5\"]:\n print(optimize_parameters(seg, \"slow1\"))\n
\n message: Optimization terminated successfully.\n success: True\n status: 0\n fun: 0.09304095804691315\n x: [ 9.910e+00 3.885e-01 1.136e-01]\n nit: 111\n nfev: 211\n final_simplex: (array([[ 9.910e+00, 3.885e-01, 1.136e-01],\n [ 9.910e+00, 3.885e-01, 1.136e-01],\n [ 9.910e+00, 3.885e-01, 1.136e-01],\n [ 9.910e+00, 3.885e-01, 1.136e-01]]), array([ 9.304e-02, 9.305e-02, 9.305e-02, 9.305e-02]))\n message: Optimization terminated successfully.\n success: True\n status: 0\n fun: 0.02368384227156639\n x: [ 1.008e+01 3.732e-01 1.332e-01]\n nit: 97\n nfev: 190\n final_simplex: (array([[ 1.008e+01, 3.732e-01, 1.332e-01],\n [ 1.008e+01, 3.732e-01, 1.332e-01],\n [ 1.008e+01, 3.732e-01, 1.332e-01],\n [ 1.008e+01, 3.732e-01, 1.332e-01]]), array([ 2.368e-02, 2.369e-02, 2.369e-02, 2.369e-02]))\n message: Optimization terminated successfully.\n success: True\n status: 0\n fun: 0.01580578088760376\n x: [ 8.666e+00 3.510e-01 1.343e-01]\n nit: 111\n nfev: 219\n final_simplex: (array([[ 8.666e+00, 3.510e-01, 1.343e-01],\n [ 8.666e+00, 3.510e-01, 1.343e-01],\n [ 8.666e+00, 3.510e-01, 1.343e-01],\n [ 8.667e+00, 3.510e-01, 1.343e-01]]), array([ 1.581e-02, 1.581e-02, 1.581e-02, 1.581e-02]))\n message: Optimization terminated successfully.\n success: True\n status: 0\n fun: 0.01700039766728878\n x: [ 8.336e+00 3.602e-01 1.210e-01]\n nit: 108\n nfev: 208\n final_simplex: (array([[ 8.336e+00, 3.602e-01, 1.210e-01],\n [ 8.336e+00, 3.601e-01, 1.210e-01],\n [ 8.335e+00, 3.601e-01, 1.210e-01],\n [ 8.335e+00, 3.602e-01, 1.210e-01]]), array([ 1.700e-02, 1.700e-02, 1.700e-02, 1.700e-02]))\n message: Optimization terminated successfully.\n success: True\n status: 0\n fun: 0.10861615836620331\n x: [ 6.784e+00 3.782e-01 5.929e-04]\n nit: 50\n nfev: 107\n final_simplex: (array([[ 6.784e+00, 3.782e-01, 5.929e-04],\n [ 6.784e+00, 3.782e-01, 5.930e-04],\n [ 6.784e+00, 3.782e-01, 5.930e-04],\n [ 6.784e+00, 3.782e-01, 5.930e-04]]), array([ 1.086e-01, 1.086e-01, 1.086e-01, 1.086e-01]))\n
\n
\n
"},{"location":"notebooks/imu_modeling/#on-what-imus-measure","title":"On \"what IMUs measure\"","text":""},{"location":"notebooks/imu_modeling/#real-world-imu","title":"Real-world IMU","text":""},{"location":"notebooks/imu_modeling/#vanilla-simulated-imu","title":"Vanilla simulated IMU","text":""},{"location":"notebooks/imu_modeling/#quasi-physical-simulation-strategy","title":"Quasi-physical simulation strategy","text":""},{"location":"notebooks/imu_modeling/#butterworth-filtering","title":"Butterworth filtering","text":""},{"location":"notebooks/imu_modeling/#optimize-low-pass-filter-parameters","title":"Optimize low-pass-filter parameters","text":""},{"location":"notebooks/knee_joint_translational_dof/","title":"Knee joint translational dof","text":"This example is available as a jupyter notebook here.
And on Google Colab here
Setup the environment if this is executed on Google Colab.
Make sure to change the runtime type to GPU
. To do this go to Runtime
-> Change runtime type
-> GPU
Otherwise, rendering won't work in Google Colab.
import os\n\ntry:\n import google.colab\n IN_COLAB = True\nexcept:\n IN_COLAB = False\n\nif IN_COLAB:\n os.system(\"pip install --quiet 'x_xy[muj] @ git+https://github.com/SimiPixel/x_xy_v2'\")\n os.system(\"pip install --quiet mediapy\")\n os.system(\"pip install --quiet matplotlib\")\n os.system(\"pip install --quiet dm-haiku\")\n
import x_xy\n# automatically detects colab or not\nx_xy.utils.setup_colab_env()\n\nimport jax\nimport jax.numpy as jnp\nimport haiku as hk\nimport mediapy as media\nimport tree_utils\n
MAX_TRANSLATION = 0.1\nROM_MIN_RAD = 0.0\nROM_MAX_RAD = jnp.pi\n\ndef build_mlp_knee(key: jax.random.PRNGKey = jax.random.PRNGKey(1)):\n\n @hk.without_apply_rng\n @hk.transform\n def mlp(x):\n net = hk.nets.MLP([10, 10, 2], activation=jnp.tanh, w_init=hk.initializers.RandomNormal())\n # normalize the x input; [0, 1]\n x = (x + ROM_MIN_RAD) / (ROM_MAX_RAD - ROM_MIN_RAD)\n # center the x input; [-0.5, 0.5]\n x = (x - 0.5)\n return net(x)\n\n example_q = jnp.zeros((1,))\n params = mlp.init(key, example_q)\n\n def forward(params, q: jax.Array):\n return jax.nn.sigmoid(mlp.apply(params, q)) * MAX_TRANSLATION\n\n return params, forward\n\ndef _knee_init_joint_params(key):\n return build_mlp_knee(key)[0]\n\n\ndef transform_fn_knee(q: jax.Array, params: jax.Array) -> x_xy.Transform:\n forward = build_mlp_knee()[1]\n pos = jnp.concatenate((forward(params, q), jnp.array([0.0])))\n axis = jnp.array([0, 0, 1.0])\n rot = x_xy.maths.quat_rot_axis(axis, jnp.squeeze(q))\n return x_xy.Transform(pos, rot)\n\n\ndef draw_fn_knee(config: x_xy.MotionConfig, key_t, key_value, dt, params):\n qs = x_xy.algorithms.jcalc._draw_rxyz(config, key_t, key_value, dt, params)\n # rom constraints\n return (qs / (2 * jnp.pi) + 0.5) * (ROM_MAX_RAD - ROM_MIN_RAD) + ROM_MIN_RAD\n\nx_xy.register_new_joint_type(\"knee\", x_xy.JointModel(transform_fn_knee, rcmg_draw_fn=draw_fn_knee, init_joint_params=_knee_init_joint_params), 1, 0)\n
HIP_REVOLUTE_JOINT = True\n\nxml_str = f\"\"\"\n<x_xy>\n<options dt=\"0.01\" gravity=\"0 0 9.81\"></options>\n<worldbody>\n<geom dim=\"0.15\" type=\"xyz\"></geom>\n<body euler=\"90 90 0\" joint=\"py\" name=\"_femur\" pos=\"0.5 0.5 0.8\">\n<geom dim=\"0.1\" type=\"xyz\"></geom>\n<body \"frozen\"}\"=\"\" else=\"\" hip_revolute_joint=\"\" if=\"\" joint=\"{\" name=\"femur\" rz\"=\"\">\n<geom dim=\"0.1\" type=\"xyz\"></geom>\n<geom dim=\"0.05 0.4\" euler=\"0 90 0\" mass=\"10\" pos=\"0.2 0 0\" type=\"capsule\"></geom>\n<body joint=\"frozen\" name=\"imu1\" pos=\"0.2 0 0.06\">\n<geom dim=\"0.05\" type=\"xyz\"></geom>\n<geom color=\"orange\" dim=\"0.05 0.05 0.02\" mass=\"0\" type=\"box\"></geom>\n</body>\n<body joint=\"knee\" name=\"tibia\" pos=\"0.4 0 0\">\n<geom dim=\"0.1\" type=\"xyz\"></geom>\n<geom dim=\"0.04 0.4\" euler=\"0 90 0\" mass=\"10\" pos=\"0.2 0 0\" type=\"capsule\"></geom>\n<body joint=\"frozen\" name=\"imu2\" pos=\"0.2 0 0.06\">\n<geom dim=\"0.05\" type=\"xyz\"></geom>\n<geom color=\"orange\" dim=\"0.05 0.05 0.02\" mass=\"0\" type=\"box\"></geom>\n</body>\n<geom dim=\"0.025 0.2 0.05\" mass=\"5.0\" pos=\"0.45 -.1 0\" type=\"box\"></geom>\n</body>\n</body>\n</body>\n</worldbody>\n</x_xy>\n\"\"\"\n\nsys = x_xy.load_sys_from_str(xml_str)\n
def finalize_fn(key, qs, xs: x_xy.Transform, sys: x_xy.System):\n X = {}\n for imu in [\"imu1\", \"imu2\"]:\n xs_imu = xs.take(sys.name_to_idx(imu), axis=1)\n X[imu] = {}\n X[imu][\"pos\"] = xs_imu.pos\n X[imu][\"quat\"] = xs_imu.rot\n X[imu][\"imu\"] = x_xy.imu(xs_imu, sys.gravity, sys.dt)\n\n params = tree_utils.tree_slice(sys.links.joint_params[\"knee\"], sys.name_to_idx(\"tibia\"))\n return qs, xs, X, params\n\ndata = x_xy.build_generator(sys, x_xy.MotionConfig(t_min=0.1, t_max=0.75, T=30), finalize_fn=finalize_fn, randomize_joint_params=True, eager=True, aslist=True, seed=1, sizes=32)\n
\neager data generation: 1it [00:07, 7.23s/it]\n
\n
idx = 5\nqs, xs, X, params = data[idx]\n
import matplotlib.pyplot as plt\n\n\nphi = jnp.linspace(0.0, jnp.pi)[:, None]\n# meter -> centimeter\ntrans_x, trans_y = jax.vmap(lambda arr: build_mlp_knee()[1](params, arr))(phi).T * 100\nplt.scatter(trans_x, trans_y, c=phi, cmap=\"coolwarm\")\nplt.colorbar()\nplt.grid()\nplt.xlabel(\"x translation [cm]\")\nplt.ylabel(\"y translation [cm]\")\n
\nText(0, 0.5, 'y translation [cm]')
\n
media.show_video(x_xy.render(sys, [xs[i] for i in range(0, xs.shape(), 4)], camera=\"target\", width=1280, height=720), fps=25)\n
\nRendering frames..: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 750/750 [00:05<00:00, 133.26it/s]\n
\n
This browser does not support the video tag. \n
"},{"location":"notebooks/knee_joint_translational_dof/#registering-a-knee-joint-type","title":"Registering a Knee Joint Type","text":""},{"location":"notebooks/machine_learning/","title":"Machine learning","text":"Note
This example is available as a jupyter notebook here.
And on Google Colab here
Setup the environment if this is executed on Google Colab.
Make sure to change the runtime type to GPU
. To do this go to Runtime
-> Change runtime type
-> GPU
Otherwise, rendering won't work in Google Colab.
import os\n\ntry:\n import google.colab\n IN_COLAB = True\nexcept:\n IN_COLAB = False\n\nif IN_COLAB:\n os.system(\"pip install --quiet 'ring @ git+https://github.com/SimiPixel/ring'\")\n os.system(\"pip install --quiet mediapy\")\n\nimport ring\n# automatically detects colab or not\nring.utils.setup_colab_env()\n\nimport mediapy\nimport jax.numpy as jnp\nimport tree_utils\nfrom ring import exp, sim2real, ml\n
imtp = exp.IMTP([\"seg2\", \"seg3\", \"seg4\"], sparse=True, joint_axes=True)\nexp_id = \"S_04\"\nmotion = \"thomas_fast\"\nringnet = ml.RING_ICML24()\nerrors, X, y, yhat, xs, xs_noimu = exp.benchmark_fn(imtp, exp_id, motion, ringnet, warmup=5.0)\nsys = imtp.sys(exp_id)\nframes = sys.render_prediction(xs, yhat, stepframe=4, transparent_segment_to_root=False, width=640, height=480, camera=\"c\", \n add_cameras={-1: '<camera mode=\"targetbody\" name=\"c\" pos=\".5 -.5 1.25\" target=\"3\"></camera>',})\n
\nDetected the following sampling rates from `X`: 100.0\n
\n
\nRendering frames..: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 1150/1150 [00:09<00:00, 126.31it/s]\n
\n
errors\n
\n{'seg2': {'mae': 4.0391045, 'std': 2.5817297},\n 'seg3': {'mae': 5.3936434, 'std': 2.996799},\n 'seg4': {'mae': 4.5474825, 'std': 2.3526802}}
\n
mediapy.show_video(frames, fps=25.0)\n
This browser does not support the video tag. \n
"},{"location":"notebooks/magnetometer_modeling/","title":"Magnetometer modeling","text":"from x_xy.subpkgs import exp\nimport matplotlib.pyplot as plt\nimport jax\nimport x_xy\nimport jax.numpy as jnp\nimport numpy as np\n\nhz = 100\nmarkerMap = {\n \"seg1\": 2,\n \"seg5\": 2,\n \"seg2\": 1,\n \"seg3\": 2,\n \"seg4\": 4\n}\n\ndef load_data(seg: str, t1: float, t2: float, motion: str = \"fast\"):\n\n data = exp.load_data(\"S_06\", motion, resample_to_hz=hz)[seg]\n\n # extract a small window from long time series for plotting\n pos, rot, imu_data = jax.tree_map(lambda arr: arr[int(t1 * hz): int(t2 * hz)], \n (data[f\"marker{markerMap[seg]}\"], data[\"quat\"], data[\"imu_rigid\"]))\n rot = x_xy.maths.quat_inv(rot)\n\n # maximal coordinates of segment, there is (almost) no sensor-to-segment orientation\n xs = x_xy.Transform.create(pos, rot)\n return pos, rot, xs, imu_data\n\n\nt1, t2 = 3.0, 9.0\npos, rot, xs, imu_data = load_data(\"seg1\", t1, t2)\n
def plot(*mag_data):\n mag_data = jax.tree_map(lambda arr: arr[:-100], mag_data)\n _, axes = plt.subplots(1, len(mag_data), figsize=(len(mag_data)*6, 4))\n axes = [axes] if not isinstance(axes, np.ndarray) else axes\n ts = jnp.arange(t1, t2 - 1.0, step=1 / hz)\n\n for i, mag in enumerate(mag_data):\n axes[i].plot(ts, mag, label=[\"x\", \"y\", \"z\"])\n axes[i].grid()\n axes[i].set_xlabel(\"time [s]\")\n\n axes[0].legend()\n\nplot(imu_data[\"mag\"])\n
imu_data_sim = x_xy.imu(xs, jnp.array([0, 0, 9.81]), 1/hz, jax.random.PRNGKey(1), has_magnetometer=True, low_pass_filter_rot_alpha=0.5)\nplot(imu_data[\"mag\"], imu_data_sim[\"mag\"])\n
from scipy.optimize import minimize\n\ndef optimize_parameters(seg: str, motion: str):\n t1, t2 = 0.0, 500.0\n pos, rot, xs, imu_data = load_data(seg, t1, t2, motion)\n\n @jax.jit\n def objective(params):\n magvec= params\n #alpha = np.clip(alpha, 0.0, 1.0)\n\n imu_sim = x_xy.imu(xs, jnp.zeros((3,)), 1 / hz, \n low_pass_filter_rot_alpha=0.5, magvec=magvec, has_magnetometer=True)\n\n return jnp.mean((imu_data[\"mag\"] - imu_sim[\"mag\"])**2)\n\n res = minimize(objective, jnp.array([0.0, .7, -.7]), method=\"Nelder-Mead\")\n\n perfect = np.array([0, res.x[1], res.x[2]])\n perfect /= np.linalg.norm(perfect)\n dip_angle = np.arctan2(perfect[1], perfect[2])\n return res.x, np.linalg.norm(res.x), np.rad2deg(dip_angle) - 90\n
for seg in [\"seg1\", \"seg2\", \"seg3\", \"seg4\", \"seg5\"]:\n print(optimize_parameters(seg, \"slow1\"))\n
\n(array([-0.05669107, 0.13636727, -0.56856133]), 0.5874282070698279, 76.5125997928424)\n(array([ 0.02870585, 0.14479726, -0.5529681 ]), 0.5723320607239446, 75.32629421386511)\n(array([ 0.07342922, 0.27993262, -0.6070893 ]), 0.6725411131056166, 65.24528171284635)\n(array([ 0.06965261, 0.12674702, -0.66338645]), 0.6789682416674281, 79.18339336223758)\n(array([-0.02896293, 0.24820061, -0.55680701]), 0.6103084782009606, 65.9747430877396)\n
\n
for seg in [\"seg1\", \"seg2\", \"seg3\", \"seg4\", \"seg5\"]:\n print(optimize_parameters(seg, \"fast\"))\n
\n(array([-0.08539633, 0.15602869, -0.49032469]), 0.5215896749593268, 72.34814710540633)\n(array([ 0.05422703, 0.13053918, -0.2375643 ]), 0.27643777497021754, 61.211650723518005)\n(array([ 0.17069941, 0.16292433, -0.49904502]), 0.5520222414783663, 71.9195779531855)\n(array([ 0.03610723, 0.06886188, -0.5142856 ]), 0.5201301476029805, 82.37356382782355)\n(array([-0.13417971, 0.32559843, -0.40739543]), 0.5385067931956347, 51.36746475752713)\n
\n
Test optimized magnetic field vector
pos, rot, xs, imu_data = load_data(\"seg1\", t1, t2, \"fast\")\nimu_data_sim = x_xy.imu(xs, jnp.array([0, 0, 9.81]), 1/hz, has_magnetometer=True, low_pass_filter_rot_alpha=0.56,\n magvec=jnp.array([-0.08957149, 0.17059967, -0.59387128]))\nplot(imu_data[\"mag\"], imu_data_sim[\"mag\"])\n
pos, rot, xs, imu_data = load_data(\"seg1\", t1, t2, \"slow1\")\nimu_data_sim = x_xy.imu(xs, jnp.array([0, 0, 9.81]), 1/hz, has_magnetometer=True, low_pass_filter_rot_alpha=0.5,\n magvec=jnp.array([-0.05896413, 0.14859727, -0.6037423 ]), noisy=True, key=jax.random.PRNGKey(7))\nplot(imu_data[\"mag\"], imu_data_sim[\"mag\"])\n
\n
"},{"location":"notebooks/magnetometer_modeling/#magnetometer-modeling","title":"Magnetometer modeling","text":""},{"location":"notebooks/magnetometer_modeling/#real-world-magnetic-field","title":"Real-world Magnetic-field","text":""},{"location":"notebooks/magnetometer_modeling/#optimize-magnetic-field-vector","title":"Optimize Magnetic Field Vector","text":""},{"location":"notebooks/morph_system/","title":"Morph system","text":"import x_xy\n\nimport jax\nimport jax.numpy as jnp\nimport numpy as np\nimport matplotlib.pyplot as plt\n\nimport mediapy as media\n\ndef show_video(sys, xs: x_xy.Transform) -> None:\n assert sys.dt == 0.01\n # only render every fourth to get a framerate of 25 fps\n frames = x_xy.render(sys, [xs[i] for i in range(0, xs.shape(), 4)], camera=\"targetfar\", height=480, width=640)\n # convert rgba to rgb\n frames = [frame[..., :3] for frame in frames]\n media.show_video(frames, fps=25)\n
In this system the middle segment seg2
acts as \"anchor\".
xml_str = \"\"\"\n<x_xy model=\"three_segment_kinematic_chain\">\n<options dt=\"0.01\" gravity=\"0 0 9.81\"></options>\n<worldbody>\n<body joint=\"free\" name=\"seg2\" pos=\"0 0 1\">\n<geom color=\"self\" dim=\"1 0.25 0.2\" mass=\"0.1\" pos=\"0.5 0 0\" type=\"box\"></geom>\n<body joint=\"ry\" name=\"seg1\">\n<geom color=\"self\" dim=\"1 0.25 0.2\" mass=\"0.1\" pos=\"-0.5 0 0\" type=\"box\"></geom>\n<body joint=\"frozen\" name=\"imu1\" pos=\"-0.5 0 0.125\">\n<geom color=\"orange\" dim=\"0.2 0.2 0.05\" mass=\"0.05\" type=\"box\"></geom>\n</body>\n</body>\n<body joint=\"rz\" name=\"seg3\" pos=\"1 0 0\">\n<geom color=\"self\" dim=\"1 0.25 0.2\" mass=\"0.1\" pos=\"0.5 0 0\" type=\"box\"></geom>\n<body joint=\"frozen\" name=\"imu2\" pos=\"0.5 0 -0.125\">\n<geom color=\"orange\" dim=\"0.2 0.2 0.05\" mass=\"0.05\" type=\"box\"></geom>\n</body>\n</body>\n</body>\n</worldbody>\n</x_xy>\n\"\"\"\n\nsys = x_xy.load_sys_from_str(xml_str)\n\ngen = x_xy.build_generator(sys, x_xy.MotionConfig(T=10.0, t_max=1.5, dang_max_free_spherical=0.1, dpos_max=0.1), _compat=True)\n_, xs = gen(jax.random.PRNGKey(1))\nshow_video(sys, xs)\n
\nRendering frames..: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 250/250 [00:01<00:00, 162.68it/s]\n
\n
This browser does not support the video tag. Can you see what i mean? The middle segment has all the \"global rotation and translation\".
Let's move the anchor to seg1
but without changing the xml syntax. This can be done with the subpackage sys_composer
.
from x_xy.subpkgs import sys_composer\n
# the new parents of seg2, seg1, imu1, seg3, imu2 are ...\nnew_parents = [\"seg1\", -1, \"seg1\", \"seg2\", \"seg3\"]\nsys = sys_composer.morph_system(sys, new_parents=new_parents)\n\ngen = x_xy.build_generator(sys, x_xy.MotionConfig(T=10.0, t_max=1.5, dang_max_free_spherical=0.1, dpos_max=0.1), _compat=True)\n_, xs = gen(jax.random.PRNGKey(1))\nshow_video(sys, xs)\n
\nRendering frames..: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 250/250 [00:01<00:00, 147.05it/s]\n
\n
This browser does not support the video tag. Pretty cool, ha? :)
\n
"},{"location":"notebooks/morph_system/#different-anchors-explains-sys_composermorph_system","title":"Different Anchors (explains sys_composer.morph_system)","text":""},{"location":"notebooks/motion_artifact_rejection/","title":"Motion artifact rejection","text":"This example is available as a jupyter notebook here.
And on Google Colab here
Setup the environment if this is executed on Google Colab.
Make sure to change the runtime type to GPU
. To do this go to Runtime
-> Change runtime type
-> GPU
Otherwise, rendering won't work in Google Colab.
import os\n\ntry:\n import google.colab\n IN_COLAB = True\nexcept:\n IN_COLAB = False\n\nif IN_COLAB:\n os.system(\"pip install --quiet 'x_xy[muj] @ git+https://github.com/SimiPixel/x_xy_v2'\")\n os.system(\"pip install --quiet mediapy\")\n
import x_xy\n# automatically detects colab or not\nx_xy.utils.setup_colab_env()\n\nimport jax\nimport jax.numpy as jnp\n\nimport mediapy as media\n\ndef show_video(sys, xs, **kwargs):\n media.show_video(x_xy.render(sys, [xs[i] for i in range(0, xs.shape(), 4)], camera=\"target\", width=640, height=480, **kwargs), fps=25)\n
knee_xml_str = \"\"\"\n<x_xy model=\"knee_flexible_imus\">\n<options dt=\"0.01\" gravity=\"0 0 9.81\"></options>\n<worldbody>\n<body damping=\"5 5 5 25 25 25\" joint=\"free\" name=\"femur\" pos=\"0.5 0.5 0.3\">\n<geom dim=\"0.1\" type=\"xyz\"></geom>\n<geom dim=\"0.05 0.4\" euler=\"0 90 0\" mass=\"1\" pos=\"0.2 0 0\" type=\"capsule\"></geom>\n<body joint=\"frozen\" name=\"imu1\" pos=\"0.2 0 0.05\" pos_max=\"0.35 0 0\" pos_min=\"0.05 0 0\">\n<geom dim=\"0.05\" type=\"xyz\"></geom>\n<geom color=\"orange\" dim=\"0.05 0.05 0.02\" mass=\"0.1\" type=\"box\"></geom>\n</body>\n<body damping=\"3\" joint=\"ry\" name=\"tibia\" pos=\"0.4 0 0\">\n<geom dim=\"0.1\" type=\"xyz\"></geom>\n<geom dim=\"0.04 0.4\" euler=\"0 90 0\" mass=\"1\" pos=\"0.2 0 0\" type=\"capsule\"></geom>\n<body joint=\"frozen\" name=\"imu2\" pos=\"0.2 0 0.05\" pos_max=\"0.35 0 0\" pos_min=\"0.05 0 0\">\n<geom dim=\"0.05\" type=\"xyz\"></geom>\n<geom color=\"orange\" dim=\"0.05 0.05 0.02\" mass=\"0.1\" type=\"box\"></geom>\n</body>\n<geom dim=\"0.025 0.05 0.2\" mass=\"0\" pos=\"0.45 0 .1\" type=\"box\"></geom>\n</body>\n</body>\n</worldbody>\n</x_xy>\n\"\"\"\n\nsys = x_xy.load_sys_from_str(knee_xml_str)\n
media.show_image(x_xy.render(sys, camera=\"target\", height=480, width=640)[0])\n
\nRendering frames..: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 1/1 [00:00<00:00, 14.47it/s]\n
\n
T = 20.0\nT_initial_nomotion = 2.0\n\nconfig = x_xy.MotionConfig(t_min=0.1, t_max=0.75, T=T, ang0_min=0.0, ang0_max=0.0, pos_min=-1.0, pos_max=1.0, dpos_max=0.5)\nconfig = x_xy.join_motionconfigs([config.to_nomotion_config(), config], [T_initial_nomotion])\n\n(X, y), (_, qs, xs, sys_mod) = x_xy.build_generator(sys, config, imu_motion_artifacts=True, dynamic_simulation=True, eager=True, \n aslist=True, seed=1, sizes=1, keep_output_extras=True, imu_motion_artifacts_kwargs=dict(hide_injected_bodies=False))[0]\n
\n/Users/simon/Documents/PYTHON/x_xy_v2/x_xy/algorithms/generator/motion_artifacts.py:80: UserWarning: `sys.links.joint_params` has been set to zero, this might lead to unexpected behaviour unless you use `randomize_joint_params`\n warnings.warn(\n/Users/simon/Documents/PYTHON/x_xy_v2/x_xy/algorithms/generator/base.py:184: UserWarning: `imu_motion_artifacts` is enabled but not `randomize_motion_artifacts`\n warnings.warn(\neager data generation: 1it [00:28, 28.97s/it]\n
\n
show_video(sys_mod, xs, show_floor=False)\n
\nRendering frames..: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 500/500 [00:01<00:00, 315.52it/s]\n
\n
This browser does not support the video tag. sys_frozen = sys_mod.freeze(\"tibia\").freeze(\"femur\")\n\ndef freeze_x(q_obs):\n q_frozen = jnp.concatenate(tuple(q_obs[:, sys_mod.idx_map(\"q\")[name]] for name in [\"_imu1\", \"imu1\", \"_imu2\", \"imu2\"]), axis=-1)\n return jax.vmap(lambda q: x_xy.algorithms.forward_kinematics_transforms(sys_frozen, q)[0])(q_frozen)\n
show_video(sys_frozen, freeze_x(qs))\n
\nRendering frames..: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 500/500 [00:02<00:00, 172.52it/s]\n
\n
This browser does not support the video tag. \n
"},{"location":"notebooks/visualisation/","title":"Visualisation","text":"import ring\nfrom ring import exp\nimport mediapy as media\nimport jax\n
sys_str = \"\"\"\n<x_xy>\n<worldbody>\n<geom dim=\"0.1\" type=\"xyz\"></geom>\n<body joint=\"free\" name=\"seg\" pos=\"0 0 .5\">\n<geom color=\"dustin_exp_blue\" dim=\"0.15 0.075 0.05\" mass=\"0.2\" pos=\"0.03 0 0\" type=\"box\"></geom>\n<body joint=\"frozen\" name=\"imu\" pos=\"0.0 0.0 0.03\">\n<geom color=\"dustin_exp_orange\" dim=\"0.05 0.03 0.02\" mass=\"0.1\" type=\"box\"></geom>\n</body>\n</body>\n</worldbody>\n</x_xy>\n\"\"\"\n
sys = ring.System.create(sys_str)\n
(X, y), (key, q, x, _) = ring.RCMG(sys, keep_output_extras=True).to_list()[0]\n
\neager data generation: 1it [00:01, 1.58s/it]\n
\n
media.show_video(sys.render(x, width=640, height=480, camera=\"target\", render_every_nth=4), fps=25)\n
\nRendering frames..: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 1500/1500 [00:08<00:00, 167.58it/s]\n
\n
This browser does not support the video tag. exp_data = exp.load_data(exp_id=\"S_06\", motion_start=\"fast\")\n
exp_data.keys()\n
\ndict_keys(['seg1', 'seg2', 'seg3', 'seg4', 'seg5'])
\n
exp_data[\"seg1\"].keys()\n
\ndict_keys(['imu_flex', 'imu_rigid', 'marker1', 'marker2', 'marker3', 'marker4', 'quat'])
\n
segment = \"seg2\"\nomc_data_sys = {\n \"seg\": {\n \"pos\": exp_data[segment][\"marker1\"],\n \"quat\": exp_data[segment][\"quat\"],\n },\n \"imu\": {\n \"quat\": exp_data[segment][\"quat\"],\n }\n}\n
omc_data_sys\n
\n{'seg': {'pos': Array([[ 0.26832035, 1.1925832 , -0.06244465],\n [ 0.268319 , 1.1925788 , -0.06244366],\n [ 0.26831627, 1.1925731 , -0.06244285],\n ...,\n [ 0.19899347, 1.2143577 , -0.06264979],\n [ 0.19898805, 1.2143621 , -0.06264362],\n [ 0.19897974, 1.214368 , -0.06263389]], dtype=float32),\n 'quat': Array([[ 0.95449424, 0.10144146, -0.01664882, -0.27995223],\n [ 0.95449567, 0.10146226, -0.01665274, -0.27993944],\n [ 0.95449716, 0.10148306, -0.01665665, -0.27992663],\n ...,\n [ 0.9600311 , -0.01755334, 0.02215116, -0.2784628 ],\n [ 0.9600333 , -0.01756094, 0.02225687, -0.27844635],\n [ 0.96003544, -0.01756854, 0.02236257, -0.27842987]], dtype=float32)},\n 'imu': {'quat': Array([[ 0.95449424, 0.10144146, -0.01664882, -0.27995223],\n [ 0.95449567, 0.10146226, -0.01665274, -0.27993944],\n [ 0.95449716, 0.10148306, -0.01665665, -0.27992663],\n ...,\n [ 0.9600311 , -0.01755334, 0.02215116, -0.2784628 ],\n [ 0.9600333 , -0.01756094, 0.02225687, -0.27844635],\n [ 0.96003544, -0.01756854, 0.02236257, -0.27842987]], dtype=float32)}}
\n
x = ring.sim2real.xs_from_raw(sys, omc_data_sys)\n\n# vectorize this function over time\n@jax.vmap\ndef update_position_vector_of_imu(x):\n state = ring.State.create(sys, x=x)\n # populate minimal coordinates `state.q` from maximal coordinates `state.x`\n state = ring.algorithms.inverse_kinematics(sys, state)\n # re-calculate maximal coordiantes `state.x` from minimal coordinates `state.q`\n # this uses the position vector specified in the system (and so the xml file)\n # to produce an offset between IMu and segment geom box\n _, state = ring.algorithms.forward_kinematics(sys, state)\n return state.x\n\nx = update_position_vector_of_imu(x)\n
media.show_video(sys.render(x, width=640, height=480, camera=\"target\", render_every_nth=4), fps=25)\n
\nRendering frames..: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 1075/1075 [00:07<00:00, 134.95it/s]\n
\n
This browser does not support the video tag. \n
"},{"location":"prism/ss_23_marcel_thomas/notebook/","title":"Notebook","text":"import x_xy\nimport jax\nimport jax.numpy as jnp\nimport jax.random as random\nfrom x_xy.subpkgs.ml import rnno, callbacks, train, load\nfrom x_xy.subpkgs import sim2real, sys_composer\nimport tree_utils\nimport matplotlib.pyplot as plt\nimport mediapy as media\n
three_seg_rigid = r\"\"\"\n<x_xy model=\"three_seg_rigid\">\n<options dt=\"0.01\" gravity=\"0 0 9.81\"></options>\n<worldbody>\n<body joint=\"free\" name=\"seg2\">\n<geom color=\"red\" dim=\"0.2 0.05 0.05\" mass=\"10\" pos=\"0.1 0 0\" type=\"box\"></geom>\n<body joint=\"rsry\" name=\"seg1\" pos=\"0 0 0\">\n<geom color=\"yellow\" dim=\"0.2 0.05 0.05\" mass=\"10\" pos=\"-0.1 0 0\" type=\"box\"></geom>\n<body joint=\"frozen\" name=\"imu1\" pos=\"-0.1 0.0 0.03\">\n<geom color=\"green\" dim=\"0.05 0.01 0.01\" mass=\"2\" pos=\"0 0 0\" type=\"box\"></geom>\n</body>\n</body>\n<body joint=\"rsrz\" name=\"seg3\" pos=\"0.2 0 0\">\n<geom color=\"blue\" dim=\"0.2 0.05 0.05\" mass=\"10\" pos=\"0.1 0 0\" type=\"box\"></geom>\n<body joint=\"frozen\" name=\"imu2\" pos=\"0.1 0.0 0.03\">\n<geom color=\"green\" dim=\"0.05 0.01 0.01\" mass=\"2\" pos=\"0 0 0\" type=\"box\"></geom>\n</body>\n</body>\n</body>\n</worldbody>\n<defaults>\n<geom color=\"1 0.8 0.7 1\" edge_color=\"black\"></geom>\n</defaults>\n</x_xy>\n\"\"\"\n
dustin_exp_xml_seg1 = r\"\"\"\n<x_xy model=\"dustin_exp\">\n<options dt=\"0.01\" gravity=\"0 0 9.81\"></options>\n<worldbody>\n<body joint=\"free\" name=\"seg1\">\n<geom dim=\"0.2 0.05 0.05\" mass=\"10\" pos=\"-0.1 0 0\" type=\"box\"></geom>\n<body joint=\"ry\" name=\"seg2\">\n<geom dim=\"0.2 0.05 0.05\" mass=\"10\" pos=\"0.1 0 0\" type=\"box\"></geom>\n<body joint=\"rz\" name=\"seg3\" pos=\"0.2 0 0\">\n<geom dim=\"0.2 0.05 0.05\" mass=\"10\" pos=\"0.1 0 0\" type=\"box\"></geom>\n</body>\n</body>\n</body>\n</worldbody>\n</x_xy>\n\"\"\"\n
# Helper function - Creates an array of values x <- [0, 1] which may be multiplied to another sequence.\ndef motion_amplifier(\n time : float,\n sampling_rate : float,\n key_rigid_phases : jax.Array,\n n_rigid_phases=3,\n rigid_duration_cov=jnp.array([0.02] * 3),\n transition_cov=jnp.array([0.1] * 3)\n) -> jax.Array:\n error_msg = \"motion_amplifier: There must be a variance for each rigid phase!\"\n assert rigid_duration_cov.shape == (n_rigid_phases,) == transition_cov.shape, error_msg\n n_frames = int(time / sampling_rate)\n key_rigid_means, key_rigid_variances, key_slope_down_variances, key_slope_up_variances = random.split(key_rigid_phases, 4)\n\n # Calculate center points of rigid phases\n means = jnp.sort(random.uniform(key_rigid_means, shape=(n_rigid_phases, 1), minval=0, maxval=n_frames).T)\n\n # Calculate durations, which is twice the rigid distance from the center points for each rigid phase.\n rigid_distances = jnp.abs(random.multivariate_normal(\n key_rigid_variances, mean=jnp.zeros_like(means), cov=jnp.diag((rigid_duration_cov * n_frames)**2)))\n\n # Calculate transition durations\n transition_slowdown_durations = jnp.abs(random.multivariate_normal(\n key_slope_down_variances, mean=jnp.zeros_like(means), cov=jnp.diag((transition_cov * n_frames)**2)\n ))\n transition_speedup_durations = jnp.abs(random.multivariate_normal(\n key_slope_up_variances, mean=jnp.zeros_like(means), cov=jnp.diag((transition_cov * n_frames)**2)\n ))\n\n # Phase start and end points\n rigid_starts = (means - rigid_distances).astype(int).flatten()\n rigid_ends = (means + rigid_distances).astype(int).flatten()\n starts_slowing = (means - rigid_distances - transition_slowdown_durations).astype(int).flatten()\n ends_moving = (means + rigid_distances + transition_speedup_durations).astype(int).flatten()\n\n # Create masks\n def create_mask(start, end):\n nonlocal n_frames\n return jnp.where(jnp.arange(n_frames) < start, 1, 0) + jnp.where(jnp.arange(n_frames) >= end, 1, 0)\n\n mask = jax.vmap(create_mask)\n rigid_mask = jnp.prod(mask(rigid_starts, rigid_ends), axis=0)\n slowdown_masks = mask(starts_slowing, rigid_starts).astype(float)\n speedup_masks = mask(rigid_ends, ends_moving).astype(float)\n\n # We have to define an inline function in order to make this code JIT-able\n def linsp(mask, start, end, begin_val, carry_fun):\n range = end - start\n def true_fun(carry, x): return (carry_fun(carry, range), 1 - carry)\n def false_fun(carry, x): return (carry, x)\n def f(carry, x): return jax.lax.cond(\n x == 0, true_fun, false_fun, *(carry, x))\n return jax.lax.scan(f, begin_val, mask)[1]\n\n linsp_desc = jax.vmap(lambda m, s1, s2: linsp( m, s1, s2, 0.0, lambda carry, range: carry + 1/range))\n slowdown_mask = jnp.prod(linsp_desc(slowdown_masks, starts_slowing, rigid_starts), axis=0)\n\n linsp_asc = jax.vmap(lambda m, s1, s2: linsp(m, s1, s2, 1.0, lambda carry, range: carry - 1/range))\n speedup_mask = jnp.prod(linsp_asc(speedup_masks, rigid_ends, ends_moving), axis=0)\n\n return jnp.min(jnp.stack([rigid_mask, slowdown_mask, speedup_mask]), axis=0)\n
# Random generator: Uses the motion_amplifier to dampen/null the randomly generated angles.\ndef random_angles_with_rigid_phases_over_time(\n key_t,\n key_ang,\n dt,\n key_rigid_phases,\n n_rigid_phases=3,\n rigid_duration_cov=jnp.array([0.02] * 3),\n transition_cov=jnp.array([0.1] * 3),\n config: x_xy.algorithms.MotionConfig=x_xy.algorithms.MotionConfig()\n) -> jax.Array:\n\n mask = motion_amplifier(\n config.T,\n dt,\n key_rigid_phases,\n n_rigid_phases,\n rigid_duration_cov,\n transition_cov)\n\n qs = x_xy.algorithms.random_angle_over_time(\n key_t=key_t,\n key_ang=key_ang,\n ANG_0=config.ang0_max,\n dang_min=config.dang_min,\n dang_max=config.dang_max,\n delta_ang_min=config.delta_ang_min,\n delta_ang_max=config.delta_ang_max,\n t_min=config.t_min,\n t_max=config.t_max,\n T=config.T,\n Ts=dt,\n randomized_interpolation=config.randomized_interpolation_angle,\n range_of_motion=config.range_of_motion_hinge,\n range_of_motion_method=config.range_of_motion_hinge_method\n )\n\n # derivate qs\n qs_diff = jnp.diff(qs, axis=0)\n\n # mulitply with motion amplifier\n qs_diff = qs_diff * mask[:-1]\n\n # integrate qs_diff\n qs_rigid_phases = jnp.concatenate((qs[0:1], jnp.cumsum(qs_diff, axis=0)))\n return qs_rigid_phases\n
BEST_RUN = (1, jnp.array([0.02]), jnp.array([0.1]))\nMANY_TINY_STOPS = (30, jnp.array([0.001] * 30), jnp.array([0.0001] * 30))\n##################################################################################\n# Define your own problem configuration here :) #\n\nPROBLEM = BEST_RUN # <- Change this assignment to use it.\n##################################################################################\n\ndef define_joints():\n def _draw_sometimes_rigid(\n config: x_xy.algorithms.MotionConfig, key_t: jax.Array, key_value: jax.Array, dt : float, joint_params : jax.Array\n ) -> jax.Array:\n key_t, key_rigid_phases = jax.random.split(key_t)\n return random_angles_with_rigid_phases_over_time(\n key_t=key_t,\n key_ang=key_value,\n dt=dt,\n key_rigid_phases=key_rigid_phases,\n n_rigid_phases=PROBLEM[0],\n rigid_duration_cov=PROBLEM[1],\n transition_cov=PROBLEM[2],\n config=config\n )\n\n def _rxyz_transform(q, _, axis):\n q = jnp.squeeze(q)\n rot = x_xy.maths.quat_rot_axis(axis, q)\n return x_xy.base.Transform.create(rot=rot)\n\n rsrx_joint = x_xy.algorithms.JointModel(\n lambda q, _: _rxyz_transform(q, _, jnp.array([1.0, 0, 0])), [None], rcmg_draw_fn=_draw_sometimes_rigid\n )\n rsry_joint = x_xy.algorithms.JointModel(\n lambda q, _: _rxyz_transform(q, _, jnp.array([0, 1.0, 0])), [None], rcmg_draw_fn=_draw_sometimes_rigid\n )\n rsrz_joint = x_xy.algorithms.JointModel(\n lambda q, _: _rxyz_transform(q, _, jnp.array([0, 0, 1.0])), [None], rcmg_draw_fn=_draw_sometimes_rigid\n )\n try:\n x_xy.algorithms.register_new_joint_type(\"rsrx\", rsrx_joint, 1)\n x_xy.algorithms.register_new_joint_type(\"rsry\", rsry_joint, 1)\n x_xy.algorithms.register_new_joint_type(\"rsrz\", rsrz_joint, 1)\n except AssertionError:\n print(\"Warning: Joints have already been registered!\")\n\ndefine_joints()\n
Note: it is also possible to support multiple problems at the same time, by implementing them as seperate joint types, or by injecting the x_xy.algorithms.MotionConfig
class e.g. by inheritance.
After we defined the joint type, we can load the system:
sys_rigid = x_xy.io.load_sys_from_str(three_seg_rigid)\nsys_inference = x_xy.io.load_sys_from_str(dustin_exp_xml_seg1)\n
def finalize_fn_imu_data(key, q, x, sys):\n imu_seg_attachment = {\"imu1\": \"seg1\", \"imu2\": \"seg3\"}\n\n X = {}\n for imu, seg in imu_seg_attachment.items():\n key, consume = jax.random.split(key)\n X[seg] = x_xy.algorithms.imu(\n x.take(sys.name_to_idx(imu), 1), sys.gravity, sys.dt, consume, True\n )\n return X\n\n\ndef finalize_fn_rel_pose_data(key, _, x, sys):\n y = x_xy.algorithms.rel_pose(sys_scan=sys_inference, xs=x, sys_xs=sys)\n return y\n\ndef finalize_fn(key, q, x, sys):\n X = finalize_fn_imu_data(key, q, x, sys)\n # Since no IMU is attached to seg2, we need to provide dummy data.\n X[\"seg2\"] = tree_utils.tree_zeros_like(X[\"seg1\"])\n y = finalize_fn_rel_pose_data(key, q, x, sys)\n return X, y\n
The generated data comes is returned in the tuple \\((\\mathbf{X}, \\mathbf{y})\\), with \\(\\mathbf{X}\\) being the generated IMU accelorometer and gyroscope data and \\(\\mathbf{y}\\) the orientation of each segment, in form of a unit quaternion.
def setup_fn_seg2(key, sys: x_xy.base.System) -> x_xy.base.System:\n def replace_pos(transforms, new_pos, name: str):\n i = sys.name_to_idx(name)\n return transforms.index_set(i, transforms[i].replace(pos=new_pos))\n\n def draw_pos_uniform(key, pos_min, pos_max):\n key, c1, c2, c3 = jax.random.split(key, num=4)\n pos = jnp.array(\n [\n jax.random.uniform(c1, minval=pos_min[0], maxval=pos_max[0]),\n jax.random.uniform(c2, minval=pos_min[1], maxval=pos_max[1]),\n jax.random.uniform(c3, minval=pos_min[2], maxval=pos_max[2]),\n ]\n )\n return key, pos\n\n ts = sys.links.transform1\n\n # seg1 relative to seg2\n key, pos = draw_pos_uniform(key, [-0.3, -0.02, -0.02], [-0.05, 0.02, 0.02])\n ts = replace_pos(ts, pos, \"seg1\")\n\n # imu1 relative to seg1\n key, pos = draw_pos_uniform(\n key, [-0.25, -0.05, -0.05], [-0.05, 0.05, 0.05])\n ts = replace_pos(ts, pos, \"imu1\")\n\n # seg3 relative to seg2\n key, pos = draw_pos_uniform(key, [0.05, -0.02, -0.02], [0.3, 0.02, 0.02])\n ts = replace_pos(ts, pos, \"seg3\")\n\n # imu2 relative to seg2\n key, pos = draw_pos_uniform(key, [0.05, -0.05, -0.05], [0.25, 0.05, 0.05])\n ts = replace_pos(ts, pos, \"imu2\")\n\n return sys.replace(links=sys.links.replace(transform1=ts))\n
With this, we can now train the model: We first define the batch size and number of epochs. For good results, a relatively large number of epochs is required, as the mean average angle error in training converges relatively late in training. Then we plug together the setup- and finalize functions in a generator function, which will provide the batched training data. A logger might also be added, such as a neptune logger. When using neptune, the environment-variables NEPTUNE_TOKEN
and NEPTUNE_PROJECT
must be set accordingly.
TRAINING_BATCH_SIZE = 80\nEPOCHS = 1500\nparams_path = \"parameters.pickle\"\nKEY_GEN = random.PRNGKey(1)\nKEY_NETWORK = random.PRNGKey(1)\n\ngen = x_xy.algorithms.build_generator(sys_rigid, x_xy.algorithms.MotionConfig(), setup_fn_seg2, finalize_fn)\ngen = x_xy.algorithms.batch_generators_lazy(gen, TRAINING_BATCH_SIZE)\n\n# Set 'upload' to True if a logger is attached.\nsave_params = callbacks.SaveParamsTrainingLoopCallback(params_path, upload=False) \n\nloggers = []\n# loggers.append(NeptuneLogger()) # You may add loggers here, e.g. a Neptune Logger\n\nnetwork = rnno.make_rnno(sys_inference)\n
WARNING! Executing this code can take a long time (due to the very high number of epochs) and will probably take up a huge portion of your memory. If you run this code on a GPU, a batch size of 80 takes more than 50 GB of VRAM, so if the execution fails, it might be because of missing GPU memory. To circumvent this, the batch size can be decreased, however, the results will suffer from that.
train(gen, EPOCHS, network, loggers=loggers, callbacks=[save_params], key_generator=KEY_GEN, key_network=KEY_NETWORK)\n
def finalize_fn_inference(key, q, x, sys):\n X = finalize_fn_imu_data(key, q, x, sys)\n y = finalize_fn_rel_pose_data(key, q, x, sys)\n return X, y, x\n\n\ndef generate_inference_data(sys, config: x_xy.algorithms.MotionConfig, seed=random.PRNGKey(1,)):\n generator = x_xy.algorithms.build_generator(sys, config, finalize_fn=finalize_fn_inference)\n X, y, xs = generator(seed)\n return X, y, xs\n
To control the data generated, the MotionConfig data is used. It contains all necessary information about the to-be-generated data series, e.g. time (config.T
), except for the sampling rate, which is stored in the system object (<sys>.dt
) and set in the XML-definition. The finalize function and its return values are similiar to the training finilaize function, however, an addtitional \\(\\mathbf{xs}\\) is returned, containing the actual position and rotation. This can be used for rendering purposes later. Also, the data is not batched, as we currently are only interested in one time series.
config = x_xy.algorithms.MotionConfig()\nprint(f\"Generating data for a time series of {config.T} seconds, with a sampling rate of {1/sys_inference.dt} Hz.\")\n\n# If you are unhappy with your data series, you can alter this seed:\nseed = random.PRNGKey(1337,)\n\nX, y, xs = generate_inference_data(sys_rigid, config, seed)\n\n# Add dummy IMU data for segment 2 (which has no IMU attached)\nX[\"seg2\"] = tree_utils.tree_zeros_like(X[\"seg1\"])\n
\nGenerating data for a time series of 60.0 seconds, with a sampling rate of 100.0 Hz.\n
\n
params = load(\"parameters.pickle\")\n
Finally, we have everything we need to do inference! Let's see how our network performs...
# Run prediction:\nX_3d = tree_utils.to_3d_if_2d(X, strict=True)\ninitial_params, state = network.init(random.PRNGKey(1,), X_3d)\nyhat, _ = network.apply(params, tree_utils.add_batch_dim(state), X_3d)\nyhat = tree_utils.to_2d_if_3d(yhat, strict=True)\n\n# Plot prediction:\ndef plot_segment(segment : str, axis : str, ax):\n axis_idx = \"xyz\".index(axis)\n euler_angles_hat_seg2 = jnp.rad2deg(x_xy.maths.quat_to_euler(yhat[segment])[:,axis_idx])\n euler_angles_seg2 = jnp.rad2deg(x_xy.maths.quat_to_euler(y[segment])[:,axis_idx])\n ax.plot(euler_angles_hat_seg2, label=\"prediction\")\n ax.set_ylim((-180, 180))\n ax.set_title(f\"{segment} ({axis}-axis)\")\n ax.plot(euler_angles_seg2, label=\"truth\")\n ax.set_xlabel(\"time [s]\")\n ax.set_ylabel(\"euler angles [deg]\")\n ax.legend()\n print(f\"{segment}: medium absolute error {jnp.average(jnp.abs(euler_angles_hat_seg2 - euler_angles_seg2))} deg\")\n\nfig, axs = plt.subplots(ncols=2, figsize=(10, 4))\nplot_segment(\"seg2\", 'y', axs[0])\nplot_segment(\"seg3\", 'z', axs[1])\nplt.show()\n
\nseg2: medium absolute error 0.524849534034729 deg\nseg3: medium absolute error 0.5137953162193298 deg\n
\n
Let's also render a video of the prediction and the truth:
# Extract translations from data-generating system...\ntranslations, rotations = sim2real.unzip_xs(sys_inference, sim2real.match_xs(sys_inference, xs, sys_rigid))\nyhat_inv = jax.tree_map(lambda quat: x_xy.maths.quat_inv(quat), yhat) \n\n# ... swap rotations with predicted ones...\nrotations_hat = [] \nfor i, name in enumerate(sys_inference.link_names):\n if name in yhat_inv:\n rotations_name = x_xy.Transform.create(rot=yhat_inv[name])\n else:\n rotations_name = rotations.take(i, axis=1)\n rotations_hat.append(rotations_name)\n\n# ... and plug the positions and rotations back together.\nrotations_hat = rotations_hat[0].batch(*rotations_hat[1:]).transpose((1, 0, 2))\nxs_hat = sim2real.zip_xs(sys_inference, translations, rotations_hat)\n\n# Create combined system that shall be rendered and its transforms\nsys_render = sys_composer.inject_system(sys_rigid, sys_inference.add_prefix_suffix(suffix=\"_hat\"))\nxs_render = x_xy.Transform.concatenate(xs, xs_hat, axis=1)\n\n# Render prediction and truth:\nframes = x_xy.render(sys_render, [xs_render[i] for i in range(xs_render.shape(axis=0))], camera='target')\nmedia.show_video([frame[..., :3] for frame in frames], fps=25)\n
\nRendering frames..: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 6000/6000 [00:16<00:00, 374.76it/s]\n
\n
This browser does not support the video tag."},{"location":"prism/ss_23_marcel_thomas/notebook/#training-the-rnno-with-rigid-phases-prism-ss2023","title":"Training the RNNO with rigid phases (PRISM SS2023)","text":"In this notebook, we define a custom hinge joint, which is configured to generate pauses (no movement) inside the generated data series. We use this joint to train the RNNO and perfrom inference with the generated parameters.
"},{"location":"prism/ss_23_marcel_thomas/notebook/#defining-the-system","title":"Defining the System","text":"A system is defined in an XML structure. To read a system, an XML file may be used. It is also possible to define the system inline by using a string in XML-syntax.In the following, we define two three-segment chains:
"},{"location":"prism/ss_23_marcel_thomas/notebook/#registering-the-joint-axis","title":"Registering the joint axis","text":"For this scenario, we define two systems: One for generating data with rigid phases and one for inference. To generate the random data with rigid phases, we first have to register a joint type, that allows for the creation of such data. We call this joint 'rsr\\<x|y|z>', a hinge joint that produces random sometimes rigid data, and turns around the respective axis \\(x\\), \\(y\\) or \\(z\\) in its frame. </x|y|z>
"},{"location":"prism/ss_23_marcel_thomas/notebook/#generating-random-data","title":"Generating random data","text":"The random data is generated by the following functions:
"},{"location":"prism/ss_23_marcel_thomas/notebook/#defining-the-random-joint-function","title":"Defining the random joint function","text":"First of all, we have to define our problem. This means, parameterzing the random function. Two possible scenarios are implemented below: \"BEST_RUN\" and \"MANY_TINY_STOPS\", both of which achieved adequate results. The problems are defined as \\(P=(N, \\mathbf{\\sigma}_{r}, \\mathbf{\\sigma}_{tr})\\), with \\(N\\) being the number of rigid phases, \\(\\mathbf{\\sigma}_r\\) the covariance used for calculating the length of each rigid phase and \\(\\mathbf{\\sigma}_{tr}\\) for the length of each transition phase respectively. It also holds that \\(\\mathbf{\\sigma}_r, \\mathbf{\\sigma}_{tr} \\in \\mathbb{R}^N\\), with each entry being the variance for exactly one rigid phase.
"},{"location":"prism/ss_23_marcel_thomas/notebook/#generating-raw-data","title":"Generating raw data","text":"For both training and inference, we first need a set of raw data. In our example, sys_rigid
is used to generate the problem-specific data for each IMU. This data will be used for training and later by sys_inference
to estimate the position and orientation of seg2
, which has no IMU attached.
Before we begin with the actual training, we first define a setup function. This is called before training on each time series. The function below alters the length of segments and the position of the IMUs of the system, to simulate inaccuracies, e.g. when dealing with experimental data.
"},{"location":"prism/ss_23_marcel_thomas/notebook/#infering-data","title":"Infering data","text":""},{"location":"prism/ss_23_marcel_thomas/notebook/#inference","title":"Inference","text":"To do inference, we first need to load the parameters (weights) of our model.
"},{"location":"prism/ss_23_moritz/notebook/","title":"Notebook","text":"import jax\nimport jax.numpy as jnp\nimport tree_utils\nfrom jax.nn import softmax\nimport matplotlib.pyplot as plt\nimport mediapy\n\nimport x_xy\nfrom x_xy.subpkgs import ml, sim2real, sys_composer\n
Set the batch size and number of training episodes according to the available hardware.
BATCHSIZE = 32\nNUM_TRAINING_EPISODES = 1500\n
sys_str = r\"\"\"\n<x_xy model=\"three_segment_kinematic_chain\">\n<options dt=\"0.01\" gravity=\"0 0 9.81\"></options>\n<defaults>\n<geom color=\"orange\"></geom>\n</defaults>\n<worldbody>\n<body joint=\"free\" name=\"seg2\" pos=\"0 0 2\">\n<geom dim=\"1 0.25 0.2\" mass=\"0.1\" pos=\"0.5 0 0\" type=\"box\"></geom>\n<body joint=\"ry\" name=\"seg1\">\n<geom dim=\"1 0.25 0.2\" mass=\"0.1\" pos=\"-0.5 0 0\" type=\"box\"></geom>\n<body joint=\"frozen\" name=\"imu1\" pos=\"-0.5 0 0.125\">\n<geom color=\"red\" dim=\"0.2 0.2 0.05\" mass=\"0.05\" type=\"box\"></geom>\n</body>\n</body>\n<body joint=\"rz\" name=\"seg3\" pos=\"1 0 0\">\n<geom dim=\"1 0.25 0.2\" mass=\"0.1\" pos=\"0.5 0 0\" type=\"box\"></geom>\n<body joint=\"frozen\" name=\"imu2\" pos=\"0.5 0 -0.125\">\n<geom color=\"red\" dim=\"0.2 0.2 0.05\" mass=\"0.05\" type=\"box\"></geom>\n</body>\n</body>\n</body>\n</worldbody>\n</x_xy>\n\"\"\"\nsys = x_xy.io.load_sys_from_str(sys_str)\n
dustin_exp_xml_seg1 = r\"\"\"\n<x_xy model=\"dustin_exp\">\n<options dt=\"0.01\" gravity=\"0 0 9.81\"></options>\n<defaults>\n<geom color=\"white\"></geom>\n</defaults>\n<worldbody>\n<body joint=\"free\" name=\"seg1\">\n<geom dim=\"1 0.25 0.2\" mass=\"10\" pos=\"-0.5 0 0\" type=\"box\"></geom>\n<body joint=\"ry\" name=\"seg2\">\n<geom dim=\"1 0.25 0.2\" mass=\"10\" pos=\"0.5 0 0\" type=\"box\"></geom>\n<body joint=\"rz\" name=\"seg3\" pos=\"0.2 0 0\">\n<geom dim=\"1 0.25 0.2\" mass=\"10\" pos=\"0.5 0 0\" type=\"box\"></geom>\n</body>\n</body>\n</body>\n</worldbody>\n</x_xy>\n\"\"\"\nsys_inference = x_xy.io.load_sys_from_str(dustin_exp_xml_seg1)\n
def finalise_fn(key: jax.Array, q: jax.Array, xs: x_xy.Transform, sys: x_xy.System):\n def xs_by_name(name: str):\n return xs.take(sys.name_to_idx(name), axis=1)\n\n key, *consume = jax.random.split(key, 3)\n\n # the input X to our RNNo is the IMU data of segments 1 and 3\n X = {\n \"seg1\": x_xy.imu(xs_by_name(\"imu1\"), sys.gravity, sys.dt, consume[0], True),\n \"seg3\": x_xy.imu(xs_by_name(\"imu2\"), sys.gravity, sys.dt, consume[1], True),\n }\n\n # seg2 has no IMU, but we still need to make an entry in our X\n X[\"seg2\"] = tree_utils.tree_zeros_like(X[\"seg1\"])\n\n # the output of the RNNo is the estimated relative poses of our segments\n y = x_xy.algorithms.rel_pose(sys_scan=sys_inference, xs=xs, sys_xs=sys)\n\n return X, y\n\nconfig = x_xy.algorithms.MotionConfig(dpos_max=0.3, ang0_min=0.0, ang0_max=0.0)\n\ngen = x_xy.build_generator(sys, config, finalize_fn=finalise_fn)\ngen = x_xy.batch_generator(gen, BATCHSIZE)\n
def make_loss_fn(beta):\n def metric_fn(q, q_hat):\n return x_xy.maths.angle_error(q, q_hat) ** 2\n\n if beta is not None:\n\n def loss_fn(q, q_hat):\n # q.shape == q_hat.shape == (1000, 4)\n angles = metric_fn(q, q_hat)\n\n factors = angles.shape[-1] * softmax(\n beta * jax.lax.stop_gradient(angles), axis=-1\n )\n\n errors = factors * angles\n\n return errors\n\n else:\n loss_fn = metric_fn\n\n return loss_fn\n
beta
determines the strength of our weighting: the larger beta, the more relative weight we put on the larger errors, while beta = 0.0
makes the scaling factors uniform one and gives us back our unweighted errors. Alternatively beta = None
bypasses the scaling altogether.
beta = 1.0\n
rnno = ml.make_rnno(sys_inference)\n\nloss_fn = make_loss_fn(beta)\n\nsave_params = ml.callbacks.SaveParamsTrainingLoopCallback(\n \"parameters.pickle\", upload=False\n)\n\nml.train(gen, NUM_TRAINING_EPISODES, rnno, callbacks=[save_params], loss_fn=loss_fn)\n
To visualise our network, we can render it using mediapy. First we generate some motion data.
gen = x_xy.build_generator(sys, config)\n\nkey = jax.random.PRNGKey(1)\n\nq, xs = gen(key)\n
We need to again bring the motion data in the correct form for our RNNo and can then run inference of the generated data.
params = ml.load(\"parameters.pickle\")\n\nX, y = finalise_fn(key, q, xs, sys)\n\nX = tree_utils.add_batch_dim(X)\n\n_, state = rnno.init(key, X)\n\nstate = tree_utils.add_batch_dim(state)\n\ny_hat, _ = rnno.apply(params, state, X)\ny_hat = tree_utils.to_2d_if_3d(y_hat, strict=True)\n
First we want to plot the angle error for both segment 2 and segment 3 over time.
y[\"seg2\"][:10]\n
y_hat[\"seg2\"]\n
fig, ax = plt.subplots()\n\nangle_error2 = jnp.rad2deg(x_xy.maths.angle_error(y[\"seg2\"], y_hat[\"seg2\"]))\nangle_error3 = jnp.rad2deg(x_xy.maths.angle_error(y[\"seg3\"], y_hat[\"seg3\"]))\n\nT = jnp.arange(angle_error2.size) * sys_inference.dt\n\nax.plot(T, angle_error2, label=\"seg2\")\nax.plot(T, angle_error3, label=\"seg3\")\n\nax.set_xlabel(\"time [s]\")\nax.set_ylabel(\"abs. angle error [deg]\")\n\nax.legend()\n\nplt.show()\n
Next we have to create an xs_hat
of the estimated orientations, so that we can render them.
# Extract translations from data-generating system...\ntranslations, rotations = sim2real.unzip_xs(\n sys_inference, sim2real.match_xs(sys_inference, xs, sys)\n)\n\ny_hat_inv = jax.tree_map(lambda quat: x_xy.maths.quat_inv(quat), y_hat) \n\n# ... swap rotations with predicted ones...\nrotations_hat = [] \nfor i, name in enumerate(sys_inference.link_names):\n if name in y_hat_inv:\n rotations_name = x_xy.Transform.create(rot=y_hat_inv[name])\n else:\n rotations_name = rotations.take(i, axis=1)\n rotations_hat.append(rotations_name)\n\n# ... and plug the positions and rotations back together.\nrotations_hat = rotations_hat[0].batch(*rotations_hat[1:]).transpose((1, 0, 2))\nxs_hat = sim2real.zip_xs(sys_inference, translations, rotations_hat)\n\n# Create combined system that shall be rendered and its transforms\nsys_render = sys_composer.inject_system(sys, sys_inference.add_prefix_suffix(suffix=\"_hat\"))\nxs_render = x_xy.Transform.concatenate(xs, xs_hat, axis=1)\n
Now we can render both the predicted system (in white) as well as the real system (in orange).
xs_list = [xs_render[i] for i in range(xs_render.shape())]\n\nframes = x_xy.render(sys_render, xs_list, camera=\"targetfar\")\nmediapy.show_video([frame[..., :3] for frame in frames], fps=int(1 / sys.dt))\n
\n
"},{"location":"prism/ss_23_moritz/notebook/#training-the-rnno-with-a-custom-loss-function","title":"Training the RNNo with a custom loss function","text":"This notebook showcases how train an RNNo network with a custom loss function rather than the default mean-reduces angle error. This is showcased by scaling the error by a softmax over the time axis, which puts more weight on the time intervals with a higher deviation compared to ones with lower deviation.
"},{"location":"prism/ss_23_moritz/notebook/#defining-the-systems","title":"Defining the systems","text":"We use two separate systems, both parsed from XML strings: one for training (sys
) and one for inference (dustin_sys
).
Our motion data will be automatically generated using a Generator
, which can be customised using an MotionConfig
. The Generator
will generate data for both q
, that is the state of all the joint angles in the system, as well as xs
, which describes the orientations of all the links in the system. To use this data for training our RNNo, we first have to bring it into the correct form using a finalise_fn
.
To customise the loss function of the RNNo, we transform the error values before they are averaged. The input to our loss function will be both \\(q\\), the real joint state, as well as \\(\\hat{q}\\), the joint space estimated by our RNNo. q
and q_hat
will both be jax.Array
s of shape (T_tbp, 4)
, where the first axis is slice over time (of our TBPTT length) and the second axis are the 4 components of a quaternion.
In this notebook we want to change the relative weightings of the errors at different times using a softmax function in order to put more weight on larger errors. First we convert the errors from quaterions to angles. Then we scale each error angle by a factor, calculated from a softmax over the angles. The calculation of the factors includes a call to jax.lax.stop_gradient
to make it so our gradients are only from the errors themselves, not the factors as well.