HMC #4
This is just a note for myself, @alecandido; the same considerations as in #2 (comment) apply here.
The problem with pymc3 is that we want to take second derivatives, and then pymc3 takes the gradient on top of that. pymc4 supports JAX as a backend, so maybe (I'm not sure) you could do it easily there; in pymc3 you would need to code all the derivatives manually, and redo them every time you change the model (see slide 24 of my seminar on lsqfitgp).

Currently, in the tests with lsqfitgp, I'm doing a Laplace approximation for the nonlinearities and the hyperparameters, after appropriately transforming the hyperparameters (log for positive ones, etc.). This often works well, especially considering that we are not interested in the hyperparameters per se: we only use them as a way of specifying a flexible prior distribution for the PDFs, so we care about the predictive error and not about getting the tails of the hyperparameter posteriors right. Moreover, I've seen the error on the currently fitted PDFs and it's small, so overall I think the fit would work without MCMC.

If we end up really needing MCMC, we could first test the fit with lsqfitgp, where it's easy to change the model, and then hardcode everything in pymc3 once we are sure of which kernels we want to use. Another alternative is to use JAX-based beta software (numpyro, pymc4, tinygp) and do some of the work ourselves, though without computing all the derivatives by hand; but considering that lsqfitgp is written with autograd, I could just as well port lsqfitgp to JAX and then plug its marginal likelihood into any NUTS implementation (see the sketch just below).
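To make the derivative-nesting issue concrete, here is a toy JAX sketch (the model and all names are made up for illustration): when the model itself contains a second derivative, the gradient that HMC/NUTS needs on top of it is effectively a third derivative, which JAX composes automatically, whereas in pymc3 you would have to supply it by hand.

```python
import jax
import jax.numpy as jnp

def model(x, theta):
    # toy stand-in for the real GP mean/observable
    return jnp.sin(theta * x)

def log_marginal_likelihood(theta, x, y):
    # the observable involves a second derivative of the model in x
    d2_model = jax.grad(jax.grad(model, argnums=0), argnums=0)
    resid = y - d2_model(x, theta)
    return -0.5 * resid ** 2

# HMC/NUTS needs the gradient w.r.t. theta of the above, i.e. a third
# derivative overall; jax.grad composes it with no extra code.
grad_logp = jax.grad(log_marginal_likelihood, argnums=0)
print(grad_logp(0.3, 1.0, 0.5))
```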
Ok, the idea is to have a NUTS-based implementation; about the details of the library providing it, I don't have any preference. I guess the proof of concept is worth the effort, even if we end up choosing pure pymc3. I'm taking the burden of providing one or the other implementation, but if you can help me, I'd be glad to accept your insights (and even practical help). Furthermore, I had a look at the implementation of NUTS, so even an implementation from scratch should not be hard (even if I'd study it and the abstract algorithm first, and cook up my own implementation in my own optimized/favorite way); a minimal sketch of the core step follows.
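For a sense of scale, a minimal from-scratch HMC step can be this small (plain NumPy, fixed trajectory length; NUTS adds the dynamic trajectory length and step-size adaptation on top, so this is only a sketch of the core):

```python
import numpy as np

def hmc_step(q, log_prob, grad_log_prob, step_size=0.1, n_leapfrog=20, rng=None):
    rng = rng or np.random.default_rng()
    p = rng.standard_normal(q.shape)  # resample momentum
    q_new, p_new = q.copy(), p.copy()
    # leapfrog integration of Hamilton's equations
    p_new += 0.5 * step_size * grad_log_prob(q_new)
    for _ in range(n_leapfrog - 1):
        q_new += step_size * p_new
        p_new += step_size * grad_log_prob(q_new)
    q_new += step_size * p_new
    p_new += 0.5 * step_size * grad_log_prob(q_new)
    # Metropolis accept/reject on the joint (q, p) energy
    h_old = -log_prob(q) + 0.5 * p @ p
    h_new = -log_prob(q_new) + 0.5 * p_new @ p_new
    return q_new if np.log(rng.uniform()) < h_old - h_new else q

# usage: sample a 2D standard normal
q = np.zeros(2)
samples = []
for _ in range(1000):
    q = hmc_step(q, lambda q: -0.5 * q @ q, lambda q: -q)
    samples.append(q)
```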
Here is the roadmap for the HMC implementation:

- [ ] pymc3
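As a hedged sketch of the alternative route mentioned above (porting the marginal likelihood to JAX and plugging it into an existing NUTS implementation): numpyro's NUTS accepts a raw potential function, so no model-definition language is needed at all. The potential below is only a toy placeholder.

```python
import jax
import jax.numpy as jnp
from numpyro.infer import MCMC, NUTS

def neg_log_marginal_likelihood(theta):
    # placeholder for a hypothetical JAX port of lsqfitgp's marginal
    # likelihood; a toy quadratic so the snippet runs as-is
    return 0.5 * jnp.sum(theta ** 2)

kernel = NUTS(potential_fn=neg_log_marginal_likelihood)
mcmc = MCMC(kernel, num_warmup=500, num_samples=1000)
mcmc.run(jax.random.PRNGKey(0), init_params=jnp.zeros(3))
samples = mcmc.get_samples()
```

Positive hyperparameters would be sampled on a transformed (e.g. log) scale, as already done for the Laplace approximation.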