Skip to content

Commit

Permalink
Merge branch 'develop' into aen_wx_1878
Browse files Browse the repository at this point in the history
  • Loading branch information
aednichols authored Dec 11, 2024
2 parents 7fd893e + 38245ca commit 424e55a
Show file tree
Hide file tree
Showing 7 changed files with 429 additions and 16 deletions.
4 changes: 3 additions & 1 deletion project/Dependencies.scala
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ object Dependencies {
private val metrics3StatsdV = "4.2.0"
private val mockFtpServerV = "3.0.0"
private val mockitoV = "3.12.4"
private val mockitoInlineV = "2.8.9"
private val mockserverNettyV = "5.14.0"
private val mouseV = "1.0.11"

Expand Down Expand Up @@ -625,7 +626,8 @@ object Dependencies {
"org.scalatest" %% "scalatest" % scalatestV,
// Use mockito Java DSL directly instead of the numerous and often hard to keep updated Scala DSLs.
// See also scaladoc in common.mock.MockSugar and that trait's various usages.
"org.mockito" % "mockito-core" % mockitoV
"org.mockito" % "mockito-core" % mockitoV,
"org.mockito" % "mockito-inline" % mockitoInlineV
) ++ slf4jBindingDependencies // During testing, add an slf4j binding for _all_ libraries.

val kindProjectorPlugin = "org.typelevel" % "kind-projector" % kindProjectorV cross CrossVersion.full
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ class BatchPollResultMonitorActor(pollMonitorParameters: PollMonitorParameters)
case event if event.name == CallMetadataKeys.VmEndTime => event.offsetDateTime
}

override def extractVmInfoFromRunState(pollStatus: RunStatus): Option[InstantiatedVmInfo] =
pollStatus.instantiatedVmInfo

override def handleVmCostLookup(vmInfo: InstantiatedVmInfo) = {
val request = GcpCostLookupRequest(vmInfo, self)
params.serviceRegistry ! request
Expand All @@ -69,6 +72,7 @@ class BatchPollResultMonitorActor(pollMonitorParameters: PollMonitorParameters)
}

override def receive: Receive = {
case costResponse: GcpCostLookupResponse => handleCostResponse(costResponse)
case message: PollResultMessage =>
message match {
case ProcessThisPollResult(pollResult: RunStatus) => processPollResult(pollResult)
Expand All @@ -93,5 +97,4 @@ class BatchPollResultMonitorActor(pollMonitorParameters: PollMonitorParameters)

override def params: PollMonitorParameters = pollMonitorParameters

override def extractVmInfoFromRunState(pollStatus: RunStatus): Option[InstantiatedVmInfo] = Option.empty // TODO
}
Original file line number Diff line number Diff line change
Expand Up @@ -1025,6 +1025,18 @@ class GcpBatchAsyncBackendJobExecutionActor(override val standardParams: Standar
} yield status
}

override val pollingResultMonitorActor: Option[ActorRef] = Option(
context.actorOf(
BatchPollResultMonitorActor.props(serviceRegistryActor,
workflowDescriptor,
jobDescriptor,
validatedRuntimeAttributes,
platform,
jobLogger
)
)
)

override def isTerminal(runStatus: RunStatus): Boolean =
runStatus match {
case _: RunStatus.TerminalRunStatus => true
Expand Down Expand Up @@ -1070,7 +1082,7 @@ class GcpBatchAsyncBackendJobExecutionActor(override val standardParams: Standar
Future.fromTry {
Try {
runStatus match {
case RunStatus.Aborted(_) => AbortedExecutionHandle
case RunStatus.Aborted(_, _) => AbortedExecutionHandle
case failedStatus: RunStatus.UnsuccessfulRunStatus => handleFailedRunStatus(failedStatus)
case unknown =>
throw new RuntimeException(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package cromwell.backend.google.batch.api.request

import com.google.api.gax.rpc.{ApiException, StatusCode}
import com.google.cloud.batch.v1.AllocationPolicy.ProvisioningModel
import com.google.cloud.batch.v1._
import com.typesafe.scalalogging.LazyLogging
import cromwell.backend.google.batch.actors.BatchApiAbortClient.{
Expand All @@ -11,6 +12,8 @@ import cromwell.backend.google.batch.api.BatchApiRequestManager._
import cromwell.backend.google.batch.api.{BatchApiRequestManager, BatchApiResponse}
import cromwell.backend.google.batch.models.{GcpBatchExitCode, RunStatus}
import cromwell.core.ExecutionEvent
import cromwell.services.cost.InstantiatedVmInfo
import cromwell.services.metadata.CallMetadataKeys

import scala.annotation.unused
import scala.concurrent.{ExecutionContext, Future, Promise}
Expand Down Expand Up @@ -136,14 +139,32 @@ object BatchRequestExecutor {
)
lazy val exitCode = findBatchExitCode(events)

// Get vm info for this job
val allocationPolicy = job.getAllocationPolicy

// Get instances that can be created with this AllocationPolicy, only instances[0] is supported
val instancePolicy = allocationPolicy.getInstances(0).getPolicy
val machineType = instancePolicy.getMachineType
val preemtible = instancePolicy.getProvisioningModelValue == ProvisioningModel.PREEMPTIBLE.getNumber

// location list = [regions/us-central1, zones/us-central1-b], region is the first element
val location = allocationPolicy.getLocation.getAllowedLocationsList.get(0)
val region =
if (location.isEmpty)
"us-central1"
else
location.split("/").last

val instantiatedVmInfo = Some(InstantiatedVmInfo(region, machineType, preemtible))

if (job.getStatus.getState == JobStatus.State.SUCCEEDED) {
RunStatus.Success(events)
RunStatus.Success(events, instantiatedVmInfo)
} else if (job.getStatus.getState == JobStatus.State.RUNNING) {
RunStatus.Running(events)
RunStatus.Running(events, instantiatedVmInfo)
} else if (job.getStatus.getState == JobStatus.State.FAILED) {
RunStatus.Failed(exitCode, events)
RunStatus.Failed(exitCode, events, instantiatedVmInfo)
} else {
RunStatus.Initializing(events)
RunStatus.Initializing(events, instantiatedVmInfo)
}
}

Expand All @@ -152,12 +173,27 @@ object BatchRequestExecutor {
GcpBatchExitCode.fromEventMessage(e.name.toLowerCase)
}.headOption

private def getEventList(events: List[StatusEvent]): List[ExecutionEvent] =
events.map { e =>
private def getEventList(events: List[StatusEvent]): List[ExecutionEvent] = {
val startedRegex = ".*SCHEDULED to RUNNING.*".r
val endedRegex = ".*RUNNING to.*".r // can be SUCCEEDED or FAILED
events.flatMap { e =>
val time = java.time.Instant
.ofEpochSecond(e.getEventTime.getSeconds, e.getEventTime.getNanos.toLong)
.atOffset(java.time.ZoneOffset.UTC)
ExecutionEvent(name = e.getDescription, offsetDateTime = time)
val eventType = e.getDescription match {
case startedRegex() => CallMetadataKeys.VmStartTime
case endedRegex() => CallMetadataKeys.VmEndTime
case _ => e.getType
}
val executionEvents = List(ExecutionEvent(name = eventType, offsetDateTime = time))

// Add an additional ExecutionEvent to capture other info if the event is a VmStartTime or VmEndTime
if (eventType == CallMetadataKeys.VmStartTime || eventType == CallMetadataKeys.VmEndTime) {
executionEvents :+ ExecutionEvent(name = e.getDescription, offsetDateTime = time)
} else {
executionEvents
}
}
}
}
}
Original file line number Diff line number Diff line change
@@ -1,24 +1,32 @@
package cromwell.backend.google.batch.models

import cromwell.core.ExecutionEvent
import cromwell.services.cost.InstantiatedVmInfo

sealed trait RunStatus {
def eventList: Seq[ExecutionEvent]
def toString: String

val instantiatedVmInfo: Option[InstantiatedVmInfo]
}

object RunStatus {

case class Initializing(eventList: Seq[ExecutionEvent]) extends RunStatus { override def toString = "Initializing" }
case class AwaitingCloudQuota(eventList: Seq[ExecutionEvent]) extends RunStatus {
case class Initializing(eventList: Seq[ExecutionEvent], instantiatedVmInfo: Option[InstantiatedVmInfo] = Option.empty)
extends RunStatus { override def toString = "Initializing" }
case class AwaitingCloudQuota(eventList: Seq[ExecutionEvent],
instantiatedVmInfo: Option[InstantiatedVmInfo] = Option.empty
) extends RunStatus {
override def toString = "AwaitingCloudQuota"
}

case class Running(eventList: Seq[ExecutionEvent]) extends RunStatus { override def toString = "Running" }
case class Running(eventList: Seq[ExecutionEvent], instantiatedVmInfo: Option[InstantiatedVmInfo] = Option.empty)
extends RunStatus { override def toString = "Running" }

sealed trait TerminalRunStatus extends RunStatus

case class Success(eventList: Seq[ExecutionEvent]) extends TerminalRunStatus {
case class Success(eventList: Seq[ExecutionEvent], instantiatedVmInfo: Option[InstantiatedVmInfo] = Option.empty)
extends TerminalRunStatus {
override def toString = "Success"
}

Expand All @@ -29,7 +37,8 @@ object RunStatus {

final case class Failed(
exitCode: Option[GcpBatchExitCode],
eventList: Seq[ExecutionEvent]
eventList: Seq[ExecutionEvent],
instantiatedVmInfo: Option[InstantiatedVmInfo] = Option.empty
) extends UnsuccessfulRunStatus {
override def toString = "Failed"

Expand Down Expand Up @@ -58,7 +67,9 @@ object RunStatus {
}
}

final case class Aborted(eventList: Seq[ExecutionEvent]) extends UnsuccessfulRunStatus {
final case class Aborted(eventList: Seq[ExecutionEvent],
instantiatedVmInfo: Option[InstantiatedVmInfo] = Option.empty
) extends UnsuccessfulRunStatus {
override def toString = "Aborted"

override val exitCode: Option[GcpBatchExitCode] = None
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
package cromwell.backend.google.batch.actors

import akka.actor.{ActorRef, ActorSystem, Props}
import akka.testkit.{TestKit, TestProbe}
import cats.data.Validated.Valid
import common.mock.MockSugar
import cromwell.backend.google.batch.models.GcpBatchRuntimeAttributes
import cromwell.backend.{BackendJobDescriptor, BackendJobDescriptorKey, RuntimeAttributeDefinition}
import cromwell.core.callcaching.NoDocker
import cromwell.core.{ExecutionEvent, WorkflowOptions}
import cromwell.core.logging.JobLogger
import cromwell.services.cost.{GcpCostLookupRequest, GcpCostLookupResponse, InstantiatedVmInfo}
import cromwell.services.keyvalue.InMemoryKvServiceActor
import org.scalatest.flatspec.AnyFlatSpecLike
import org.scalatest.matchers.should.Matchers
import cromwell.backend.google.batch.models.GcpBatchTestConfig._
import wom.graph.CommandCallNode
import cromwell.backend._
import cromwell.backend.google.batch.models._
import cromwell.backend.io.TestWorkflows
import cromwell.backend.standard.pollmonitoring.ProcessThisPollResult
import cromwell.services.metadata.CallMetadataKeys
import cromwell.services.metadata.MetadataService.PutMetadataAction
import org.slf4j.helpers.NOPLogger
import wom.values.WomString

import java.time.{Instant, OffsetDateTime}
import java.time.temporal.ChronoUnit
import scala.concurrent.duration.DurationInt

class BatchPollResultMonitorActorSpec
extends TestKit(ActorSystem("BatchPollResultMonitorActorSpec"))
with AnyFlatSpecLike
with BackendSpec
with Matchers
with MockSugar {

var kvService: ActorRef = system.actorOf(Props(new InMemoryKvServiceActor), "kvService")
val runtimeAttributesBuilder = GcpBatchRuntimeAttributes.runtimeAttributesBuilder(gcpBatchConfiguration)
val jobLogger = mock[JobLogger]
val serviceRegistry = TestProbe()

val workflowDescriptor = buildWdlWorkflowDescriptor(TestWorkflows.HelloWorld)
val call: CommandCallNode = workflowDescriptor.callable.taskCallNodes.head
val jobKey = BackendJobDescriptorKey(call, None, 1)

val jobDescriptor = BackendJobDescriptor(workflowDescriptor,
jobKey,
runtimeAttributes = Map.empty,
evaluatedTaskInputs = Map.empty,
NoDocker,
None,
Map.empty
)

val runtimeAttributes = Map("docker" -> WomString("ubuntu:latest"))

val staticRuntimeAttributeDefinitions: Set[RuntimeAttributeDefinition] =
GcpBatchRuntimeAttributes.runtimeAttributesBuilder(GcpBatchTestConfig.gcpBatchConfiguration).definitions.toSet

val defaultedAttributes =
RuntimeAttributeDefinition.addDefaultsToAttributes(staticRuntimeAttributeDefinitions,
WorkflowOptions.fromMap(Map.empty).get
)(
runtimeAttributes
)
val validatedRuntimeAttributes = runtimeAttributesBuilder.build(defaultedAttributes, NOPLogger.NOP_LOGGER)

val actor = system.actorOf(
BatchPollResultMonitorActor.props(serviceRegistry.ref,
workflowDescriptor,
jobDescriptor,
validatedRuntimeAttributes,
Some(Gcp),
jobLogger
)
)
val vmInfo = InstantiatedVmInfo("europe-west9", "custom-16-32768", false)

behavior of "BatchPollResultMonitorActor"

it should "send a cost lookup request with the correct vm info after receiving a success pollResult" in {

val terminalPollResult =
RunStatus.Success(Seq(ExecutionEvent("fakeEvent", OffsetDateTime.now().truncatedTo(ChronoUnit.MILLIS))),
Some(vmInfo)
)
val message = ProcessThisPollResult(terminalPollResult)

actor ! message

serviceRegistry.expectMsgPF(1.seconds) { case m: GcpCostLookupRequest =>
m.vmInfo shouldBe vmInfo
}
}

it should "emit the correct cost metadata after receiving a costLookupResponse" in {

val costLookupResponse = GcpCostLookupResponse(vmInfo, Valid(BigDecimal(0.1)))

actor ! costLookupResponse

serviceRegistry.expectMsgPF(1.seconds) { case m: PutMetadataAction =>
val event = m.events.head
m.events.size shouldBe 1
event.key.key shouldBe CallMetadataKeys.VmCostPerHour
event.value.get.value shouldBe "0.1"
}
}

it should "emit the correct start time after receiving a running pollResult" in {

val vmStartTime = OffsetDateTime.now().minus(2, ChronoUnit.HOURS)
val pollResult = RunStatus.Running(
Seq(ExecutionEvent(CallMetadataKeys.VmStartTime, vmStartTime)),
Some(vmInfo)
)
val message = ProcessThisPollResult(pollResult)

actor ! message

serviceRegistry.expectMsgPF(1.seconds) { case m: PutMetadataAction =>
val event = m.events.head
m.events.size shouldBe 1
event.key.key shouldBe CallMetadataKeys.VmStartTime
assert(
Instant
.parse(event.value.get.value)
.equals(vmStartTime.toInstant.truncatedTo(ChronoUnit.MILLIS))
)
}
}

it should "emit the correct end time after receiving a running pollResult" in {

val vmEndTime = OffsetDateTime.now().minus(2, ChronoUnit.HOURS)
val pollResult = RunStatus.Running(
Seq(ExecutionEvent(CallMetadataKeys.VmEndTime, vmEndTime)),
Some(vmInfo)
)
val message = ProcessThisPollResult(pollResult)

actor ! message

serviceRegistry.expectMsgPF(1.seconds) { case m: PutMetadataAction =>
val event = m.events.head
m.events.size shouldBe 1
event.key.key shouldBe CallMetadataKeys.VmEndTime
assert(
Instant
.parse(event.value.get.value)
.equals(vmEndTime.toInstant.truncatedTo(ChronoUnit.MILLIS))
)
}
}
}
Loading

0 comments on commit 424e55a

Please sign in to comment.