-
Are there any native utilities for loading datasets / TensorBoard?
- Processing data:
- TensorBoard:
Would love to help out with porting over some of these utilities to JAX in a functional manner.
-
Conclusion: the costs of rolling your own data and logging solution outweigh the benefits; the most standard solution would probably be to use tools from the Google ecosystem (flax.metrics.tensorboard, tf.data), as evidenced by the Flax examples. It is a bit unfortunate: for instance, mapping tf.data.Dataset (https://www.tensorflow.org/api_docs/python/tf/data/Dataset#map) seems like a natural use case for dropping in JAX transformations instead of resorting to TF operations. Sounds familiar! EDIT: See @jheek's answer.
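As a rough sketch of that "Google ecosystem" route (the dataset contents below are invented for illustration): build the input pipeline with tf.data, keep the map stage in TF ops, and hand plain NumPy batches to JAX at the end.

```python
import numpy as np
import tensorflow as tf
import jax.numpy as jnp

# Toy data standing in for a real dataset.
xs = np.random.normal(size=(1024, 8)).astype(np.float32)
ys = np.random.normal(size=(1024, 1)).astype(np.float32)

ds = (
    tf.data.Dataset.from_tensor_slices((xs, ys))
    # The map stage must use TF ops, not JAX transformations --
    # exactly the limitation lamented above.
    .map(lambda x, y: (x * 2.0, y), num_parallel_calls=tf.data.AUTOTUNE)
    .shuffle(1024)
    .batch(32)
    .prefetch(tf.data.AUTOTUNE)
)

for x_batch, y_batch in ds.as_numpy_iterator():
    # The pipeline yields NumPy arrays, which JAX consumes directly.
    x_batch = jnp.asarray(x_batch)
```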
-
There are some issues with a "JAX native" data loading pipeline. At its core tf.data is like a scheduler with buffers and tasks that run in parallel (map is not vectorizing like jax.vmap but instead parallelizing over CPU threads). Secondly, JAX doesn't support dynamic shapes, and it isn't trivial to handle things like JPEG, audio, and video formats. TF has ops that support all of these natively.

PyTorch has the same issue. It provides a thin wrapper around multiprocessing, which is just another library for scheduling tasks into a pool of threads/processes, but PyTorch itself doesn't know how to parse a JPEG. The big difference is that TF embeds preprocessing into the TF graph, so it's more seamless but less modular compared to PyTorch.

You can also use PyTorch data loaders with JAX. In the end all these pipelines produce NumPy buffers, which can be used with JAX without any overhead. PyTorch data loaders could also use JAX, similar to how you can use PyTorch: for example, after decoding and cropping an image you could call some JAX op to do further preprocessing on it.

For TensorBoard, again there isn't much JAX can do here. TensorBoard just writes NumPy buffers to a file in a ProtoBuf encoding. JAX doesn't have ops to do IO or to encode ProtoBufs. I don't see this as a big issue though: IO and encoding/decoding are very different from the computational ops JAX supports. Mixing them together would make it much harder to reason about JAX and would move away from the modular and functional approach that it uses now.
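A minimal sketch of the PyTorch-loader route described above. The `ToyDataset` and `numpy_collate` names are hypothetical, not from the thread: the loader only schedules and batches, the collate function hands back NumPy arrays instead of torch.Tensors, and a jitted JAX function does the further preprocessing step mentioned in the reply.

```python
import numpy as np
from torch.utils.data import DataLoader, Dataset
import jax
import jax.numpy as jnp

class ToyDataset(Dataset):
    # Hypothetical dataset; a real one would decode JPEGs etc. here.
    def __len__(self):
        return 1024

    def __getitem__(self, idx):
        # Return NumPy, not torch.Tensor, so no framework conversion is needed.
        image = np.full((8, 8, 3), idx % 256, dtype=np.float32)
        label = np.float32(idx % 10)
        return image, label

def numpy_collate(batch):
    # Stack samples into NumPy batches instead of torch.Tensors.
    images, labels = zip(*batch)
    return np.stack(images), np.asarray(labels)

# num_workers=0 keeps the sketch runnable anywhere; a real pipeline
# would raise it to parallelize over processes.
loader = DataLoader(ToyDataset(), batch_size=32, num_workers=0,
                    collate_fn=numpy_collate)

@jax.jit
def preprocess(images):
    # A JAX op applied after decoding/cropping, as suggested above.
    return jnp.clip(images / 255.0, 0.0, 1.0)

for images, labels in loader:
    images = preprocess(images)  # NumPy buffers flow into JAX with no overhead
```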
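On the logging side, a small sketch of the "TensorBoard just writes NumPy buffers to a file" point using flax.metrics.tensorboard (the wrapper named in the conclusion above), assuming its SummaryWriter takes a log_dir and exposes a scalar(tag, value, step) method; the log directory and synthetic loss values are made up for illustration.

```python
import numpy as np
from flax.metrics import tensorboard

# JAX does no IO here; the writer serializes NumPy values to event files.
writer = tensorboard.SummaryWriter(log_dir="/tmp/jax_logs")
for step in range(100):
    loss = np.exp(-step / 30.0)  # stand-in for a real training loss
    writer.scalar("train/loss", loss, step)
writer.close()
```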