Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Undeploy models with no WorkerNodes #3380

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,21 @@

import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX;

import java.time.Instant;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;

import org.opensearch.OpenSearchStatusException;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.bulk.BulkRequest;
import org.opensearch.action.bulk.BulkResponse;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.action.support.WriteRequest;
import org.opensearch.action.update.UpdateRequest;
import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.inject.Inject;
Expand All @@ -32,9 +37,11 @@
import org.opensearch.index.query.TermsQueryBuilder;
import org.opensearch.ml.cluster.DiscoveryNodeHelper;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.model.MLModelState;
import org.opensearch.ml.common.transport.deploy.MLDeployModelRequest;
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelAction;
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelNodesRequest;
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelNodesResponse;
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsAction;
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsRequest;
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsResponse;
Expand All @@ -51,6 +58,7 @@
import org.opensearch.transport.TransportService;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableMap;

import lombok.extern.log4j.Log4j2;

Expand Down Expand Up @@ -156,11 +164,68 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLUnde
private void undeployModels(String[] targetNodeIds, String[] modelIds, ActionListener<MLUndeployModelsResponse> listener) {
MLUndeployModelNodesRequest mlUndeployModelNodesRequest = new MLUndeployModelNodesRequest(targetNodeIds, modelIds);

client.execute(MLUndeployModelAction.INSTANCE, mlUndeployModelNodesRequest, ActionListener.wrap(r -> {
listener.onResponse(new MLUndeployModelsResponse(r));
client.execute(MLUndeployModelAction.INSTANCE, mlUndeployModelNodesRequest, ActionListener.wrap(response -> {
/*
* The method TransportUndeployModelsAction.processUndeployModelResponseAndUpdate(...) performs
* undeploy action of models by removing the models from the nodes cache and updating the index when it's able to find it.
*
* The problem becomes when the models index is incorrect and no node(s) are servicing the model. This results in
* `{}` responses (on undeploy action), with no update to the model index thus, causing incorrect model state status.
*
* Having this change enables a check that this edge case occurs along with having access to the model id
* allowing us to update the stale model index correctly to `UNDEPLOYED` since no nodes service the model.
*/
if (response.getNodes().isEmpty()) {
brianf-aws marked this conversation as resolved.
Show resolved Hide resolved
bulkSetModelIndexToUndeploy(modelIds, listener, response);
return;
}
listener.onResponse(new MLUndeployModelsResponse(response));
}, listener::onFailure));
}

private void bulkSetModelIndexToUndeploy(
String[] modelIds,
ActionListener<MLUndeployModelsResponse> listener,
MLUndeployModelNodesResponse response
) {
BulkRequest bulkUpdateRequest = new BulkRequest();
for (String modelId : modelIds) {
UpdateRequest updateRequest = new UpdateRequest();

ImmutableMap.Builder<String, Object> builder = ImmutableMap.builder();
builder.put(MLModel.MODEL_STATE_FIELD, MLModelState.UNDEPLOYED.name());

builder.put(MLModel.PLANNING_WORKER_NODES_FIELD, List.of());
builder.put(MLModel.PLANNING_WORKER_NODE_COUNT_FIELD, 0);

builder.put(MLModel.LAST_UPDATED_TIME_FIELD, Instant.now().toEpochMilli());
builder.put(MLModel.CURRENT_WORKER_NODE_COUNT_FIELD, 0);
updateRequest.index(ML_MODEL_INDEX).id(modelId).doc(builder.build());
bulkUpdateRequest.add(updateRequest);
}

bulkUpdateRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
log.info("No nodes service: {}", Arrays.toString(modelIds));
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about clarifying more ?

Suggested change
log.info("No nodes service: {}", Arrays.toString(modelIds));
log.info("No nodes running these models: {}", Arrays.toString(modelIds));

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Valid feedback added to commit 77f6e5b


try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) {
ActionListener<MLUndeployModelsResponse> listenerWithContextRestoration = ActionListener
.runBefore(listener, () -> threadContext.restore());
ActionListener<BulkResponse> bulkResponseListener = ActionListener.wrap(br -> {
log.debug("Successfully set modelIds to UNDEPLOY in index");
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add model ids to log?

listenerWithContextRestoration.onResponse(new MLUndeployModelsResponse(response));
}, e -> {
log.error("Failed to set modelIds to UNDEPLOY in index", e);
listenerWithContextRestoration.onFailure(e);
});

client.bulk(bulkUpdateRequest, bulkResponseListener);
} catch (Exception e) {
log.error("Unexpected error while setting modelIds to UNDEPLOY status to index", e);
listener.onFailure(e);
}

}

private void validateAccess(String modelId, ActionListener<Boolean> listener) {
User user = RestActionUtils.getUserContext(client);
boolean isSuperAdmin = isSuperAdminUserWrapper(clusterService, client);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,17 @@
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX;
import static org.opensearch.ml.task.MLPredictTaskRunnerTests.USER_STRING;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

import org.junit.Before;
import org.junit.Rule;
Expand All @@ -29,7 +32,10 @@
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.opensearch.action.FailedNodeException;
import org.opensearch.action.bulk.BulkRequest;
import org.opensearch.action.bulk.BulkResponse;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.update.UpdateRequest;
import org.opensearch.client.Client;
import org.opensearch.cluster.ClusterName;
import org.opensearch.cluster.service.ClusterService;
Expand All @@ -42,6 +48,7 @@
import org.opensearch.ml.cluster.DiscoveryNodeHelper;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.model.MLModelState;
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelNodeResponse;
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelNodesResponse;
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsRequest;
Expand Down Expand Up @@ -164,6 +171,129 @@ public void setup() throws IOException {
}).when(mlModelManager).getModel(any(), any(), any(), isA(ActionListener.class));
}

/**
 * When the node-level undeploy call answers without any node entries, the action is expected to issue exactly
 * one bulk update that marks the model as UNDEPLOYED in the model index, and to still answer the caller.
 */
public void testDoExecute_undeployModelIndex_WhenNoNodesServiceModel() {
    final String modelId = "someModelId";
    final MLModel hiddenModel = MLModel
        .builder()
        .user(User.parse(USER_STRING))
        .modelGroupId("111")
        .version("111")
        .name("Test Model")
        .modelId(modelId)
        .algorithm(FunctionName.BATCH_RCF)
        .content("content")
        .totalChunks(2)
        .isHidden(true)
        .build();

    // Run as super admin and let the model lookup hand back the hidden model.
    doReturn(true).when(transportUndeployModelsAction).isSuperAdminUserWrapper(clusterService, client);
    doAnswer(invocation -> {
        ActionListener<MLModel> modelListener = invocation.getArgument(3);
        modelListener.onResponse(hiddenModel);
        return null;
    }).when(mlModelManager).getModel(any(), any(), any(), isA(ActionListener.class));

    // The undeploy call answers with no node entries at all, which should trigger the write-back that
    // records the model as UNDEPLOYED in the model index.
    final List<MLUndeployModelNodeResponse> noNodeResponses = new ArrayList<>();
    final List<FailedNodeException> noFailures = new ArrayList<>();
    final MLUndeployModelNodesResponse emptyNodesResponse = new MLUndeployModelNodesResponse(clusterName, noNodeResponses, noFailures);
    doAnswer(invocation -> {
        ActionListener<MLUndeployModelNodesResponse> undeployListener = invocation.getArgument(2);
        undeployListener.onResponse(emptyNodesResponse);
        return null;
    }).when(client).execute(any(), any(), isA(ActionListener.class));

    // Capture the bulk request so its contents can be inspected, and answer it with a failure-free response.
    final ArgumentCaptor<BulkRequest> bulkRequestCaptor = ArgumentCaptor.forClass(BulkRequest.class);
    doAnswer(invocation -> {
        ActionListener<BulkResponse> bulkListener = invocation.getArgument(1);
        BulkResponse cleanBulkResponse = mock(BulkResponse.class);
        when(cleanBulkResponse.hasFailures()).thenReturn(false);
        bulkListener.onResponse(cleanBulkResponse);
        return null;
    }).when(client).bulk(bulkRequestCaptor.capture(), any(ActionListener.class));

    MLUndeployModelsRequest request = new MLUndeployModelsRequest(
        new String[] { modelId },
        new String[] { "test_node_id1", "test_node_id2" }
    );

    transportUndeployModelsAction.doExecute(task, request, actionListener);

    // Exactly one update must have been queued for the single requested model.
    BulkRequest sentBulkRequest = bulkRequestCaptor.getValue();
    assertEquals(1, sentBulkRequest.numberOfActions());
    UpdateRequest sentUpdate = (UpdateRequest) sentBulkRequest.requests().get(0);

    @SuppressWarnings("unchecked")
    Map<String, Object> writtenFields = sentUpdate.doc().sourceAsMap();

    assertEquals("Check that the write happened at the model index", ML_MODEL_INDEX, sentUpdate.index());
    assertEquals("Check that the result bulk write hit this specific modelId", modelId, sentUpdate.id());

    // The partial document must reset the deployment state of the model.
    assertEquals(MLModelState.UNDEPLOYED.name(), writtenFields.get(MLModel.MODEL_STATE_FIELD));
    assertEquals(0, writtenFields.get(MLModel.CURRENT_WORKER_NODE_COUNT_FIELD));
    assertEquals(0, writtenFields.get(MLModel.PLANNING_WORKER_NODE_COUNT_FIELD));
    assertEquals(List.of(), writtenFields.get(MLModel.PLANNING_WORKER_NODES_FIELD));
    assertTrue(writtenFields.containsKey(MLModel.LAST_UPDATED_TIME_FIELD));

    verify(actionListener).onResponse(any(MLUndeployModelsResponse.class));
    verify(client).bulk(any(BulkRequest.class), any(ActionListener.class));
}

/**
 * When the node-level undeploy call answers with node entries, the action is expected to answer the caller
 * directly and must not issue any bulk write against the model index.
 */
public void testDoExecute_noBulkRequestFired_WhenSomeNodesServiceModel() {
    final String modelId = "someModelId";
    final MLModel hiddenModel = MLModel
        .builder()
        .user(User.parse(USER_STRING))
        .modelGroupId("111")
        .version("111")
        .name("Test Model")
        .modelId(modelId)
        .algorithm(FunctionName.BATCH_RCF)
        .content("content")
        .totalChunks(2)
        .isHidden(true)
        .build();

    // Run as super admin and let the model lookup hand back the hidden model.
    doReturn(true).when(transportUndeployModelsAction).isSuperAdminUserWrapper(clusterService, client);
    doAnswer(invocation -> {
        ActionListener<MLModel> modelListener = invocation.getArgument(3);
        modelListener.onResponse(hiddenModel);
        return null;
    }).when(mlModelManager).getModel(any(), any(), any(), isA(ActionListener.class));

    // Two node responses and two node failures: the response is not empty.
    final List<MLUndeployModelNodeResponse> nodeResponses = new ArrayList<>(
        List.of(mock(MLUndeployModelNodeResponse.class), mock(MLUndeployModelNodeResponse.class))
    );
    final List<FailedNodeException> nodeFailures = new ArrayList<>(
        List.of(mock(FailedNodeException.class), mock(FailedNodeException.class))
    );
    final MLUndeployModelNodesResponse populatedResponse = new MLUndeployModelNodesResponse(clusterName, nodeResponses, nodeFailures);

    // The undeploy call answers with the populated response.
    doAnswer(invocation -> {
        ActionListener<MLUndeployModelNodesResponse> undeployListener = invocation.getArgument(2);
        undeployListener.onResponse(populatedResponse);
        return null;
    }).when(client).execute(any(), any(), isA(ActionListener.class));

    MLUndeployModelsRequest request = new MLUndeployModelsRequest(
        new String[] { modelId },
        new String[] { "test_node_id1", "test_node_id2" }
    );

    transportUndeployModelsAction.doExecute(task, request, actionListener);

    verify(actionListener).onResponse(any(MLUndeployModelsResponse.class));
    // Nodes were servicing the model, so the model index must not have been touched by a bulk write.
    verify(client, never()).bulk(any(BulkRequest.class), any(ActionListener.class));
}

public void testHiddenModelSuccess() {
MLModel mlModel = MLModel
.builder()
Expand All @@ -186,16 +316,28 @@ public void testHiddenModelSuccess() {
List<MLUndeployModelNodeResponse> responseList = new ArrayList<>();
List<FailedNodeException> failuresList = new ArrayList<>();
MLUndeployModelNodesResponse response = new MLUndeployModelNodesResponse(clusterName, responseList, failuresList);

doAnswer(invocation -> {
ActionListener<MLUndeployModelNodesResponse> listener = invocation.getArgument(2);
listener.onResponse(response);
return null;
}).when(client).execute(any(), any(), isA(ActionListener.class));

// Mock the client.bulk call
doAnswer(invocation -> {
ActionListener<BulkResponse> listener = invocation.getArgument(1);
BulkResponse bulkResponse = mock(BulkResponse.class);
when(bulkResponse.hasFailures()).thenReturn(false);
listener.onResponse(bulkResponse);
return null;
}).when(client).bulk(any(BulkRequest.class), any(ActionListener.class));

doReturn(true).when(transportUndeployModelsAction).isSuperAdminUserWrapper(clusterService, client);
MLUndeployModelsRequest request = new MLUndeployModelsRequest(modelIds, nodeIds);
transportUndeployModelsAction.doExecute(task, request, actionListener);

verify(actionListener).onResponse(any(MLUndeployModelsResponse.class));
verify(client).bulk(any(BulkRequest.class), any(ActionListener.class));
}

public void testHiddenModelPermissionError() {
Expand Down Expand Up @@ -249,9 +391,19 @@ public void testDoExecute() {
listener.onResponse(response);
return null;
}).when(client).execute(any(), any(), isA(ActionListener.class));
// Mock the client.bulk call
doAnswer(invocation -> {
ActionListener<BulkResponse> listener = invocation.getArgument(1);
BulkResponse bulkResponse = mock(BulkResponse.class);
when(bulkResponse.hasFailures()).thenReturn(false);
listener.onResponse(bulkResponse);
return null;
}).when(client).bulk(any(BulkRequest.class), any(ActionListener.class));

MLUndeployModelsRequest request = new MLUndeployModelsRequest(modelIds, nodeIds);
transportUndeployModelsAction.doExecute(task, request, actionListener);
verify(actionListener).onResponse(any(MLUndeployModelsResponse.class));
verify(client).bulk(any(BulkRequest.class), any(ActionListener.class));
}

public void testDoExecute_modelAccessControl_notEnabled() {
Expand Down
Loading