From 0167ec40b297380c0c373613318394c0c33a5a87 Mon Sep 17 00:00:00 2001 From: Vinay Anantharaman Date: Mon, 7 Mar 2022 11:02:52 -0800 Subject: [PATCH] [components][aws_sagemaker] Handle Stopped State --- .../unit_tests/tests/train/test_train_component.py | 13 +++++++++++++ .../train/src/sagemaker_training_component.py | 2 +- 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/components/aws/sagemaker/tests/unit_tests/tests/train/test_train_component.py b/components/aws/sagemaker/tests/unit_tests/tests/train/test_train_component.py index 33cba07df77..08976db410a 100644 --- a/components/aws/sagemaker/tests/unit_tests/tests/train/test_train_component.py +++ b/components/aws/sagemaker/tests/unit_tests/tests/train/test_train_component.py @@ -120,6 +120,19 @@ def test_get_job_status(self): SageMakerJobStatus(is_completed=True, raw_status="Completed"), ) + self.component._get_debug_rule_status = MagicMock( + return_value=SageMakerJobStatus( + is_completed=True, has_error=False, raw_status="Stopped" + ) + ) + self.component._sm_client.describe_training_job.return_value = { + "TrainingJobStatus": "Stopped" + } + self.assertEqual( + self.component._get_job_status(), + SageMakerJobStatus(is_completed=True, raw_status="Stopped"), + ) + self.component._sm_client.describe_training_job.return_value = { "TrainingJobStatus": "Failed", "FailureReason": "lolidk", diff --git a/components/aws/sagemaker/train/src/sagemaker_training_component.py b/components/aws/sagemaker/train/src/sagemaker_training_component.py index ab7effeafda..bc4e816afcd 100644 --- a/components/aws/sagemaker/train/src/sagemaker_training_component.py +++ b/components/aws/sagemaker/train/src/sagemaker_training_component.py @@ -54,7 +54,7 @@ def _get_job_status(self) -> SageMakerJobStatus: ) status = response["TrainingJobStatus"] - if status == "Completed": + if status == "Completed" or status == "Stopped": return self._get_debug_rule_status() if status == "Failed": message = response["FailureReason"]