Skip to content

Commit

Permalink
v0.0.3; adds more methods and interactive viewer
Browse files Browse the repository at this point in the history
  • Loading branch information
simon-bachhuber committed Dec 8, 2024
1 parent 486cde1 commit ade6e64
Show file tree
Hide file tree
Showing 33 changed files with 1,780 additions and 759 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
.DS_Store
__pycache__
testing.ipynb
*.egg-info
*.egg-info
hidden
Binary file added conversion/0x643a580673fc3f4.pickle
Binary file not shown.
Binary file added conversion/0xe58c20067500b83.pickle
Binary file not shown.
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
rnn_layers=[800] * 3,
linear_layers=[400] * 1,
act_fn_rnn=lambda X: X,
params="/Users/simon/Downloads/0x643a580673fc3f4.pickle",
params="/Users/simon/Documents/PYTHON/imt/to_onnx/0x643a580673fc3f4.pickle",
celltype="gru",
scale_X=False,
).unwrapped
Expand All @@ -35,7 +35,7 @@ def timestep(a1, a2, g1, g2, state_tm1):

a = jnp.ones((3,), dtype=np.float32)

filename = "relOri-1D2D3D-100Hz-v0.onnx"
filename = "rnno-rO-100Hz-v0.onnx"

if not os.path.exists(filename):
ring.ml.ml_utils.to_onnx(
Expand Down
104 changes: 104 additions & 0 deletions conversion/conversion-rnno.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
import os
import time

import jax.flatten_util
import jax.numpy as jnp
import numpy as np
import onnxruntime as ort
import ring

X = jnp.zeros((1, 2, 6))


net = ring.ml.RING(
params="/Users/simon/Documents/PYTHON/imt/to_onnx/0xe58c20067500b83.pickle",
celltype="gru",
lam=(-1, 0),
layernorm=True,
forward_factory=ring.ml.rnno_v1.rnno_v1_forward_factory,
rnn_layers=[800] * 3,
linear_layers=[400] * 1,
act_fn_rnn=lambda X: X,
output_dim=8,
)
net = ring.ml.base.NoGraph_FilterWrapper(net, quat_normalize=True)


_, state0 = net.init(X=X)
state_flat, unflatten = jax.flatten_util.ravel_pytree(state0)


def timestep(a1, a2, g1, g2, state_tm1):
grav, pi = jnp.array(9.81), jnp.array(2.2)
X = jnp.concatenate(
(
jnp.concatenate((a1 / grav, g1 / pi))[None],
jnp.concatenate((a2 / grav, g2 / pi))[None],
)
)[None]
yhat, state = net.apply(X=X, state=unflatten(state_tm1))
return yhat[0], jax.flatten_util.ravel_pytree(state)[0]


a = jnp.ones((3,), dtype=np.float32)

filename = "rnno-100Hz-v0.onnx"

if not os.path.exists(filename):
ring.ml.ml_utils.to_onnx(
timestep,
filename,
a,
a,
a,
a,
state_flat,
in_args_names=[
"acc1 (3,) [m/s^2]",
"acc2 (3,) [m/s^2]",
"gyr1 (3,) [rad/s]",
"gyr2 (3,) [rad/s]",
"previous_state (2400,); init with zeros",
],
out_args_names=[
"quats (2, 4); incl-body1-to-eps, body2-to-body1",
"next_state",
],
validate=True,
)


# TIMING TESTING #
def time_f(f, N: int = 1000):
# maybe JIT
f()
t1 = time.time()
for _ in range(N):
f()
print(f"One executation of `f` took: {((time.time() - t1) / N) * 1000}ms")


jit_timestep = jax.jit(timestep)
print("JAX version")
time_f(lambda: jit_timestep(a, a, a, a, state_flat))

a = np.array(a)
state_flat = np.array(state_flat)
session = ort.InferenceSession(filename)


def onnx_timestep():
session.run(
None,
{
"acc1 (3,) [m/s^2]": a,
"acc2 (3,) [m/s^2]": a,
"gyr1 (3,) [rad/s]": a,
"gyr2 (3,) [rad/s]": a,
"previous_state (2400,); init with zeros": state_flat,
},
)


print("ONNX version")
time_f(onnx_timestep)
27 changes: 27 additions & 0 deletions conversion/conversion_ring.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import haiku as hk
import jax.numpy as jnp
import ring
import tree_utils


@hk.without_apply_rng
@hk.transform_with_state
def f(X):
X, prev_message_p, prev_mailbox_i = X
prev_state = hk.get_state("inner_cell_state", [2, 400])
X = tree_utils.batch_concat_acme((X, prev_message_p, prev_mailbox_i), 0)
output, next_state = ring.ml.ringnet.StackedRNNCell("gru", 400, 2, True)(
X, prev_state
)
hk.set_state("inner_cell_state", next_state)
next_message_i = hk.nets.MLP([400, 200])(next_state[-1])
output = hk.nets.MLP([400, 4])(output)
output = output / jnp.linalg.norm(output, axis=-1, keepdims=True)
return output, next_message_i


params = ring.utils.pickle_load(
"~/Documents/PYTHON/ring/src/ring/ml/params/0x13e3518065c21cd8.pickle"
)
X = (jnp.zeros((10,)), jnp.zeros((200,)), jnp.zeros((200,)))
f.apply(params, {"~": {"inner_cell_state": jnp.zeros((2, 400))}}, X)
204 changes: 0 additions & 204 deletions examples/TK.ipynb

This file was deleted.

244 changes: 0 additions & 244 deletions examples/imt_lower_extremities.ipynb

This file was deleted.

296 changes: 296 additions & 0 deletions examples/lower_extremities.ipynb

Large diffs are not rendered by default.

169 changes: 169 additions & 0 deletions examples/lower_extremities_2.ipynb

Large diffs are not rendered by default.

237 changes: 237 additions & 0 deletions media/make_gait_video.ipynb

Large diffs are not rendered by default.

5 changes: 3 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "imt-imt"
version = "0.0.2"
version = "0.0.3"
authors = [
{ name="Simon Bachhuber", email="[email protected]" },
]
Expand All @@ -19,7 +19,8 @@ classifiers = [
dependencies = [
"qmt",
"onnxruntime",
"dm-tree"
"dm-tree",
"scipy",
]

[project.urls]
Expand Down
32 changes: 23 additions & 9 deletions readme.md
Original file line number Diff line number Diff line change
@@ -1,12 +1,23 @@
![gait-example](media/gait_demo.gif)

# High-level Interface for Inertial Motion Tracking
# Plug-and-Play Inertial Motion Tracking

This python package combines many well-established methods to provide a unified interface applicable to diverse inertial motion tracking tasks.

Plug-and-play solutions for standard use-cases are provided, such as
- knee joint angle tracking
- hip joint angle tracking
- arm tracking
- lower extremities / gait tracking (see `/examples/lower_extremities.ipynb`)
- full-body motion capture

Most methods can be applied both online, allowing for real-time motion tracking, as well as offline.

## Installation

```pip install git+https://github.com/simon-bachhuber/imt.git```

# Example for three-segment KC
## Example for three-segment KC

```python
import imt
Expand All @@ -15,22 +26,25 @@ import numpy as np
# Define a graph with one body connecting to the world/earth (0) and two child bodies (1 and 2)
graph = [-1, 0, 0]

# Define the solutions that are used for solving the relative orientation subproblems in the graph
solutions = [
# Define the methods that are used for solving the relative orientation subproblems in the graph
methods = [
# use `vqf` because this body connects to earth, so there is no constraint to exploit
imt.solutions.VQF_Solution(),
imt.methods.VQF(),
# let's assume there is a 1-DOF joint between body '1' and body '0'
imt.solutions.QMT_HeadingConstraintSolution(dof=1),
imt.methods.HeadCor(dof=1),
# let's assume we don't know how many DOFs there are between body '2' and body '0', so we
# will use a general-purpose solution (but which will be less accurate)
imt.solutions.Online_RelOri_1D2D3D_Solution()
# will use a general-purpose method (but which will be less accurate)
imt.methods.RNNO()
]
# We can also let methods be `None` then a set of default methods will be determined auto-
# matically based on the graph
methods = None

# The sampling rate of the IMU data
Ts = 0.01 # Sampling time (100 Hz)

# Initialize the solver
solver = imt.Solver(graph, problem, Ts)
solver = imt.Solver(graph, methods, Ts)

# Define IMU data for the bodies (non-batched)
imu_data = {
Expand Down
4 changes: 3 additions & 1 deletion src/imt/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,4 @@
from . import _solutions as solutions
from . import _graph
from . import methods
from . import utils
from ._solver import Solver
Loading

0 comments on commit ade6e64

Please sign in to comment.