Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use MBAR bootstrap error #1077

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open

Use MBAR bootstrap error #1077

wants to merge 7 commits into from

Conversation

jthorton
Copy link
Collaborator

@jthorton jthorton commented Jan 14, 2025

Fixes #1012 by using the bootstrap error from pymbar3/4.

Would this be a good time to switch to only supporting pymbar4 so we only have to maintain a single interface for MBAR?

Note:

  • the full pymbar4 package brings in JAX
  • I found that 1000 iterations of bootstrapping only takes around 1 min for the default protocol (using jax)
  • For the extended charge changing protocol this can take up to 15 mins (using jax)
  • The variability in the dDGs between test runs was larger which meant I had to relax the relative tolerance on the tests

Checklist

  • Added a news entry

Developers certificate of origin

@jthorton jthorton requested review from IAlibay and atravitz January 14, 2025 17:21
Copy link

codecov bot commented Jan 14, 2025

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 91.65%. Comparing base (915d110) to head (9b8d3ad).

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1077      +/-   ##
==========================================
- Coverage   94.46%   91.65%   -2.82%     
==========================================
  Files         135      135              
  Lines       10090    10083       -7     
==========================================
- Hits         9532     9242     -290     
- Misses        558      841     +283     
Flag Coverage Δ
fast-tests 91.65% <100.00%> (?)
slow-tests ?

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@IAlibay
Copy link
Member

IAlibay commented Jan 15, 2025

Would this be a good time to switch to only supporting pymbar4 so we only have to maintain a single interface for MBAR?

Yes I think that would be a good idea - if we think it's stable (we might have to benchmark a bit), we should make the jump.
If we go by spec0 rules pymbar 3 is > 2 years old.

PyMBAR 3 also has all kinds of stability issues we should try to avoid.

the full pymbar4 package brings in JAX

:/ how big of a dependency is JAX? It might be that we don't really have an option here. I know you can use pymbar 4 without JAX (that's how it gets deployed on PyPi). cc @atravitz

For the extended charge changing protocol this can take up to 15 mins (using jax)

Oof that's quite long. I guess as long as we're only doing that once in a multi-hour simulation it doesn't matter too much.

@jthorton
Copy link
Collaborator Author

JAX is around 60MB, but we can use pymbar-core which is the non-JAX version that should be a bit slower, how much slower, I am not sure but compared to a multi-hour simulation it should still be negligible!

  + jax                     0.4.35  pyhd8ed1ab_1         conda-forge/noarch        1MB
  + jaxlib                  0.4.35  cpu_py312hadfe8e1_0  conda-forge/osx-64       56MB

On the other hand, adding JAX is not too noticeable compared to the cudatoolkit?

Copy link
Member

@IAlibay IAlibay left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Couple of todos:

  • Could you add a news entry please?
  • Could you make the necessary changes to switch to pymbar 4 please?

np.array([0.07471 , 0.052914, 0.041508, 0.036613, 0.032827, 0.030489,
0.028154, 0.026529, 0.025284, 0.023968]),
rtol=1e-04,
np.array([0.077645, 0.054695, 0.044680, 0.03947, 0.034822,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for updating these - the new error values are expected to be different, so we should update things where we can.

rtol=1e-04,
np.array([0.077645, 0.054695, 0.044680, 0.03947, 0.034822,
0.033443, 0.030793, 0.028777, 0.026683, 0.026199]),
rtol=1e-01,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a little bit loose as a tolerance, but I guess it's fine given the bootstraps are stochastic.

except AttributeError:
r = mbar.compute_free_energy_differences()
# pymbar 4
mbar = MBAR(u_ln, N_l, solver_protocol="robust", n_bootstraps=1000, bootstrap_solver_protocol="robust")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is most of the cost in the forward & reverse analysis?

Copy link
Collaborator Author

@jthorton jthorton Jan 15, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah running the bootstrapping on repeat is expensive! One thought on the forward and backward estimates should we be subsampling using g_t calculated for this subset of data? In the industry benchmarking I calculated it 3 ways no subsampling, subsample based on the % of data and subsample using the g_t calculated for the full set of data. https://github.com/OpenFreeEnergy/IndustryBenchmarks2024/blob/fb60d7a971cb5d04787d796b6adcf257d905786a/industry_benchmarks/analysis/1_download_and_extract_data.py#L464-L552

@IAlibay
Copy link
Member

IAlibay commented Jan 15, 2025

On the other hand, adding JAX is not too noticeable compared to the cudatoolkit?

Yeah - I also suspect we're picking up a ton of dependencies elsewhere.

Long term maybe we should look into an openfe-base version that has the very minimal set of dependencies for everything.

I'll let @atravitz weigh in, but generally I'm ok / would very much like it if we pushed for pymbar4 w/ JAX.

@IAlibay
Copy link
Member

IAlibay commented Jan 15, 2025

Completely forgot to ask @jthorton - could you have a look through our docs and see if there's anywhere we can make it clear that this is now the bootstrap error? I know some folks got confused by it all.

@jthorton
Copy link
Collaborator Author

Currently blocked by perses=0.10.3 which pins to pymbar3.

Copy link

No API break detected ✅

Copy link
Member

@IAlibay IAlibay left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jthorton I'm going to approve with the expectation that we can merge once tests pass once there's a new release of perses.

@jthorton
Copy link
Collaborator Author

We might need to delay the import of pymbar, while testing out the CLI for the partial charge generation I see that it prints a lot of info to the terminal, this gets even worse when using multiprocessing.

openfe charge-molecules -M malt1_ligands.sdf -o charged_ligands.sdf -w 6
Warning on use of the timeseries module: If the inherent timescales of the system are long compared to those being analyzed, this statistical inefficiency may be an underestimate.  The estimate presumes the use of many statistically independent samples.  Tests should be performed to assess whether this condition is satisfied.   Be cautious in the interpretation of the data.

****** PyMBAR will use 64-bit JAX! *******
* JAX is currently set to 32-bit bitsize *
* which is its default.                  *
*                                        *
* PyMBAR requires 64-bit mode and WILL   *
* enable JAX's 64-bit mode when called.  *
*                                        *
* This MAY cause problems with other     *
* Uses of JAX in the same code.          *
******************************************

SMALL MOLECULE PARTIAL CHARGE GENERATOR
_________________________________________

@atravitz
Copy link
Contributor

We might need to delay the import of pymbar, while testing out the CLI for the partial charge generation I see that it prints a lot of info to the terminal, this gets even worse when using multiprocessing.

openfe charge-molecules -M malt1_ligands.sdf -o charged_ligands.sdf -w 6
Warning on use of the timeseries module: If the inherent timescales of the system are long compared to those being analyzed, this statistical inefficiency may be an underestimate.  The estimate presumes the use of many statistically independent samples.  Tests should be performed to assess whether this condition is satisfied.   Be cautious in the interpretation of the data.

****** PyMBAR will use 64-bit JAX! *******
* JAX is currently set to 32-bit bitsize *
* which is its default.                  *
*                                        *
* PyMBAR requires 64-bit mode and WILL   *
* enable JAX's 64-bit mode when called.  *
*                                        *
* This MAY cause problems with other     *
* Uses of JAX in the same code.          *
******************************************

SMALL MOLECULE PARTIAL CHARGE GENERATOR
_________________________________________

Are we able to suppress just this warning? Otherwise, delaying the import sounds good to me.

@atravitz
Copy link
Contributor

Also, adding JAX as a dependency is fine by me

@IAlibay
Copy link
Member

IAlibay commented Jan 17, 2025

We might need to delay the import of pymbar, while testing out the CLI for the partial charge generation I see that it prints a lot of info to the terminal, this gets even worse when using multiprocessing.

openfe charge-molecules -M malt1_ligands.sdf -o charged_ligands.sdf -w 6
Warning on use of the timeseries module: If the inherent timescales of the system are long compared to those being analyzed, this statistical inefficiency may be an underestimate.  The estimate presumes the use of many statistically independent samples.  Tests should be performed to assess whether this condition is satisfied.   Be cautious in the interpretation of the data.

****** PyMBAR will use 64-bit JAX! *******
* JAX is currently set to 32-bit bitsize *
* which is its default.                  *
*                                        *
* PyMBAR requires 64-bit mode and WILL   *
* enable JAX's 64-bit mode when called.  *
*                                        *
* This MAY cause problems with other     *
* Uses of JAX in the same code.          *
******************************************

SMALL MOLECULE PARTIAL CHARGE GENERATOR
_________________________________________

Are we able to suppress just this warning? Otherwise, delaying the import sounds good to me.

Oof yeah @jthorton if you don't do the warning supression here, could you open an issue about doing it later? This is OpenMMTools levels of noisy.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Switch to bootstrapping for MBAR errors.
3 participants