Skip to content

Commit

Permalink
Make prediction query actually work with the probability_space option. (
Browse files Browse the repository at this point in the history
facebookresearch#474)

Summary:
Pull Request resolved: facebookresearch#474

Prediction query messages used to called a non-existent strategy method called predict_probability. This meant that there would be an exception whenever this is tried. The method hasn't existed for at least 2 years.

Fixed to use `strat.predict()` with the correct probability_space flag instead.

Because of this change, further changes was necessary to ensure models that would return tneors with gradients would work (e.g., pairwise probit).

Reviewed By: crasanders

Differential Revision: D66997335

fbshipit-source-id: e8243ea08e8bcb37814f5275577d9fb1046408ed
  • Loading branch information
JasonKChow authored and facebook-github-bot committed Dec 10, 2024
1 parent 62aab89 commit 28243d7
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 11 deletions.
4 changes: 2 additions & 2 deletions aepsych/models/pairwise_probit.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,8 +202,8 @@ def predict(

if probability_space:
return (
promote_0d(norm.cdf(fmean)),
promote_0d(norm.cdf(fvar)),
promote_0d(norm.cdf(fmean.detach().numpy())),
promote_0d(norm.cdf(fvar.detach().numpy())),
)
else:
return fmean, fvar
Expand Down
16 changes: 7 additions & 9 deletions aepsych/server/message_handlers/handle_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,16 +54,14 @@ def query(
# returns the model value at x
if x is None: # TODO: ensure if x is between lb and ub
raise RuntimeError("Cannot query model at location = None!")
if probability_space:
mean, _var = server.strat.predict_probability(
server._config_to_tensor(x).unsqueeze(axis=0),
)
else:
mean, _var = server.strat.predict(
server._config_to_tensor(x).unsqueeze(axis=0)
)

mean, _var = server.strat.predict(
server._config_to_tensor(x).unsqueeze(axis=0),
probability_space=probability_space,
)
response["x"] = x
response["y"] = np.array(mean.item()) # mean.item()
y = mean.item() if isinstance(mean, torch.Tensor) else mean[0]
response["y"] = np.array(y) # mean.item()

elif query_type == "inverse":
nearest_y, nearest_loc = server.strat.inv_query(
Expand Down
10 changes: 10 additions & 0 deletions tests/server/message_handlers/test_query_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,15 +199,25 @@ def test_grad_model_smoketest(self):
"x": {"x": [0.0], "y": [1.0]},
},
}
query_pred_prob = {
"type": "query",
"message": {
"query_type": "prediction",
"x": {"x": [0.0], "y": [1.0]},
"probability_space": True,
},
}
query_inv_req = {
"type": "query",
"message": {
"query_type": "inverse",
"y": 5.0,
},
}

self.s.handle_request(query_min_req)
self.s.handle_request(query_pred_req)
self.s.handle_request(query_pred_prob)
self.s.handle_request(query_max_req)
self.s.handle_request(query_inv_req)

Expand Down

0 comments on commit 28243d7

Please sign in to comment.