Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding r2max function #19

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 46 additions & 5 deletions arsenic/stats.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,43 @@
import numpy as np

def bootstrap_statistic(y_true, y_pred, dy_true=None, dy_pred=None, ci=0.95, statistic='RMSE', nbootstrap = 1000, plot_type='dG'):

def calc_R2max(affinities, errors):
    """Calculate the maximum achievable R2 given affinities and their uncertainties.

    The estimate is one minus the ratio of the experimental error variance
    to the variance of the measured affinities:

        R2max = 1 - mean(errors**2) / var(affinities)

    NOTE(review): per-ligand errors are pooled as their mean square; confirm
    this pooling and the formula against the definition used in [1].

    [1] https://pubs.acs.org/doi/abs/10.1021/acs.jcim.9b01067
        Sheridan, Robert P., et al. "Experimental Error, Kurtosis, Activity Cliffs, and Methodology: What Limits the
        Predictivity of Quantitative Structure–Activity Relationship Models?."
        Journal of Chemical Information and Modeling
        60.4 (2020): 1969-1982.

    Parameters
    ----------
    affinities : list[float]
        List of affinities for a set (order not important)
    errors : list[float] or float
        Either a single experimental error (standard deviation) for a set of
        compounds or a per-ligand list of errors (order not important)

    Returns
    -------
    float
        value of R2max. If the affinities have zero variance the ratio is
        undefined and numpy's division-by-zero result (-inf or nan) is returned.

    Raises
    ------
    ValueError
        If no affinities are given, or if a per-ligand list of errors does
        not have the same length as the affinities.

    """
    affinities = np.asarray(affinities, dtype=float)
    if affinities.size == 0:
        raise ValueError('Need at least one affinity to calculate R2max.')

    if np.ndim(errors) == 0:
        # A scalar (python or numpy, int or float) applies to every compound.
        errors = np.full(affinities.shape, float(errors))
    else:
        errors = np.asarray(errors, dtype=float)

    if len(errors) != len(affinities):
        raise ValueError(
            'Need the same number of errors as affinities. '
            f'{len(errors)} errors for {len(affinities)} affinities, '
            'or one error as a float for the whole set'
        )

    # The errors are standard deviations, so their mean square is the pooled
    # experimental variance; np.std(errors) would instead measure the spread
    # *between* the uncertainties (zero when one error is shared by all).
    error_variance = np.mean(np.square(errors))
    return 1. - error_variance / np.var(affinities)


def bootstrap_statistic(y_true, y_pred,
dy_true=None, dy_pred=None,
ci=0.95,
statistic='RMSE',
nbootstrap=1000,
plot_type='dG'):
import sklearn.metrics
import scipy
"""Compute mean and confidence intervals of specified statistic.
Expand Down Expand Up @@ -89,11 +126,12 @@ def unique_differences(x):
assert len(y_true) == len(dy_true)
assert len(y_true) == len(dy_pred)
sample_size = len(y_true)
s_n = np.zeros([nbootstrap], np.float64) # s_n[n] is the statistic computed for bootstrap sample n
# s_n[n] is the statistic computed for bootstrap sample n
s_n = np.zeros([nbootstrap], np.float64)
for replicate in range(nbootstrap):
y_true_sample = np.zeros_like(y_true)
y_pred_sample = np.zeros_like(y_pred)
for i,j in enumerate(np.random.choice(np.arange(sample_size), size=[sample_size], replace=True)):
for i, j in enumerate(np.random.choice(np.arange(sample_size), size=[sample_size], replace=True)):
y_true_sample[i] = np.random.normal(loc=y_true[j], scale=np.fabs(dy_true[j]), size=1)
y_pred_sample[i] = np.random.normal(loc=y_pred[j], scale=np.fabs(dy_pred[j]), size=1)
s_n[replicate] = compute_statistic(y_true_sample, y_pred_sample, statistic)
Expand Down Expand Up @@ -121,7 +159,9 @@ def mle(g, factor='f_ij', node_factor=None):
We assume the free energy of node 0 is zero.

Reference : https://pubs.acs.org/doi/abs/10.1021/acs.jcim.9b00528
Xu, Huafeng. "Optimal measurement network of pairwise differences." Journal of Chemical Information and Modeling 59.11 (2019): 4720-4728.
Xu, Huafeng. "Optimal measurement network of pairwise differences."
Journal of Chemical Information and Modeling
59.11 (2019): 4720-4728.

Parameters
----------
Expand Down Expand Up @@ -151,7 +191,8 @@ def mle(g, factor='f_ij', node_factor=None):
df_ij = form_edge_matrix(g, factor.replace('_', '_d'), action='symmetrize')
else:
f_ij = form_edge_matrix(g, factor, action='antisymmetrize', node_label=node_factor)
df_ij = form_edge_matrix(g, factor.replace('_', '_d'), action='symmetrize', node_label=node_factor.replace('_', '_d'))
df_ij = form_edge_matrix(g, factor.replace('_', '_d'), action='symmetrize',
node_label=node_factor.replace('_', '_d'))

node_name_to_index = {}
for i, name in enumerate(g.nodes()):
Expand Down