From f469ecbbd448d6fce8d731f4fd825ef968140522 Mon Sep 17 00:00:00 2001 From: Adam Nichols Date: Wed, 9 Aug 2023 16:23:10 -0400 Subject: [PATCH 01/25] WX-1078 ACR support (#7192) --- build.sbt | 1 + .../azure}/AzureCredentials.scala | 3 +- core/src/main/resources/reference.conf | 9 ++ .../docker/DockerImageIdentifier.scala | 11 +- .../cromwell/docker/DockerInfoActor.scala | 2 + .../registryv2/DockerRegistryV2Abstract.scala | 6 +- .../flows/azure/AcrAccessToken.scala | 3 + .../flows/azure/AcrRefreshToken.scala | 3 + .../flows/azure/AzureContainerRegistry.scala | 149 ++++++++++++++++++ .../docker/DockerImageIdentifierSpec.scala | 2 + .../cromwell/docker/DockerInfoActorSpec.scala | 14 +- .../blob/BlobFileSystemManager.scala | 2 +- 12 files changed, 197 insertions(+), 8 deletions(-) rename {filesystems/blob/src/main/scala/cromwell/filesystems/blob => cloudSupport/src/main/scala/cromwell/cloudsupport/azure}/AzureCredentials.scala (98%) create mode 100644 dockerHashing/src/main/scala/cromwell/docker/registryv2/flows/azure/AcrAccessToken.scala create mode 100644 dockerHashing/src/main/scala/cromwell/docker/registryv2/flows/azure/AcrRefreshToken.scala create mode 100644 dockerHashing/src/main/scala/cromwell/docker/registryv2/flows/azure/AzureContainerRegistry.scala diff --git a/build.sbt b/build.sbt index 9bdd086d731..814e966b364 100644 --- a/build.sbt +++ b/build.sbt @@ -165,6 +165,7 @@ lazy val databaseMigration = (project in file("database/migration")) lazy val dockerHashing = project .withLibrarySettings("cromwell-docker-hashing", dockerHashingDependencies) + .dependsOn(cloudSupport) .dependsOn(core) .dependsOn(core % "test->test") .dependsOn(common % "test->test") diff --git a/filesystems/blob/src/main/scala/cromwell/filesystems/blob/AzureCredentials.scala b/cloudSupport/src/main/scala/cromwell/cloudsupport/azure/AzureCredentials.scala similarity index 98% rename from filesystems/blob/src/main/scala/cromwell/filesystems/blob/AzureCredentials.scala rename to cloudSupport/src/main/scala/cromwell/cloudsupport/azure/AzureCredentials.scala index ae84e39adbe..c29155056a9 100644 --- a/filesystems/blob/src/main/scala/cromwell/filesystems/blob/AzureCredentials.scala +++ b/cloudSupport/src/main/scala/cromwell/cloudsupport/azure/AzureCredentials.scala @@ -1,4 +1,4 @@ -package cromwell.filesystems.blob +package cromwell.cloudsupport.azure import cats.implicits.catsSyntaxValidatedId import com.azure.core.credential.TokenRequestContext @@ -9,7 +9,6 @@ import common.validation.ErrorOr.ErrorOr import scala.concurrent.duration._ import scala.jdk.DurationConverters._ - import scala.util.{Failure, Success, Try} /** diff --git a/core/src/main/resources/reference.conf b/core/src/main/resources/reference.conf index a3ac76e949c..d2cc9c5171f 100644 --- a/core/src/main/resources/reference.conf +++ b/core/src/main/resources/reference.conf @@ -411,6 +411,15 @@ docker { max-retries = 3 // Supported registries (Docker Hub, Google, Quay) can have additional configuration set separately + azure { + // Worst case `ReadOps per minute` value from official docs + // https://github.com/MicrosoftDocs/azure-docs/blob/main/includes/container-registry-limits.md + throttle { + number-of-requests = 1000 + per = 60 seconds + } + num-threads = 10 + } google { // Example of how to configure throttling, available for all supported registries throttle { diff --git a/dockerHashing/src/main/scala/cromwell/docker/DockerImageIdentifier.scala b/dockerHashing/src/main/scala/cromwell/docker/DockerImageIdentifier.scala index 
9fbd173303b..a798f351f17 100644 --- a/dockerHashing/src/main/scala/cromwell/docker/DockerImageIdentifier.scala +++ b/dockerHashing/src/main/scala/cromwell/docker/DockerImageIdentifier.scala @@ -1,5 +1,7 @@ package cromwell.docker +import cromwell.docker.registryv2.flows.azure.AzureContainerRegistry + import scala.util.{Failure, Success, Try} sealed trait DockerImageIdentifier { @@ -14,7 +16,14 @@ sealed trait DockerImageIdentifier { lazy val name = repository map { r => s"$r/$image" } getOrElse image // The name of the image with a repository prefix if a repository was specified, or with a default repository prefix of // "library" if no repository was specified. - lazy val nameWithDefaultRepository = repository.getOrElse("library") + s"/$image" + lazy val nameWithDefaultRepository = { + // In ACR, the repository is part of the registry domain instead of the path + // e.g. `terrabatchdev.azurecr.io` + if (host.exists(_.contains(AzureContainerRegistry.domain))) + image + else + repository.getOrElse("library") + s"/$image" + } lazy val hostAsString = host map { h => s"$h/" } getOrElse "" // The full name of this image, including a repository prefix only if a repository was explicitly specified. lazy val fullName = s"$hostAsString$name:$reference" diff --git a/dockerHashing/src/main/scala/cromwell/docker/DockerInfoActor.scala b/dockerHashing/src/main/scala/cromwell/docker/DockerInfoActor.scala index 40a4c74cb9b..3ebb8d98f39 100644 --- a/dockerHashing/src/main/scala/cromwell/docker/DockerInfoActor.scala +++ b/dockerHashing/src/main/scala/cromwell/docker/DockerInfoActor.scala @@ -14,6 +14,7 @@ import cromwell.core.actor.StreamIntegration.{BackPressure, StreamContext} import cromwell.core.{Dispatcher, DockerConfiguration} import cromwell.docker.DockerInfoActor._ import cromwell.docker.registryv2.DockerRegistryV2Abstract +import cromwell.docker.registryv2.flows.azure.AzureContainerRegistry import cromwell.docker.registryv2.flows.dockerhub.DockerHubRegistry import cromwell.docker.registryv2.flows.google.GoogleRegistry import cromwell.docker.registryv2.flows.quay.QuayRegistry @@ -232,6 +233,7 @@ object DockerInfoActor { // To add a new registry, simply add it to that list List( + ("azure", { c: DockerRegistryConfig => new AzureContainerRegistry(c) }), ("dockerhub", { c: DockerRegistryConfig => new DockerHubRegistry(c) }), ("google", { c: DockerRegistryConfig => new GoogleRegistry(c) }), ("quay", { c: DockerRegistryConfig => new QuayRegistry(c) }) diff --git a/dockerHashing/src/main/scala/cromwell/docker/registryv2/DockerRegistryV2Abstract.scala b/dockerHashing/src/main/scala/cromwell/docker/registryv2/DockerRegistryV2Abstract.scala index a7cc1969903..bb25cb4bc3d 100644 --- a/dockerHashing/src/main/scala/cromwell/docker/registryv2/DockerRegistryV2Abstract.scala +++ b/dockerHashing/src/main/scala/cromwell/docker/registryv2/DockerRegistryV2Abstract.scala @@ -107,7 +107,7 @@ abstract class DockerRegistryV2Abstract(override val config: DockerRegistryConfi } // Execute a request. 
No retries because they're expected to already be handled by the client - private def executeRequest[A](request: IO[Request[IO]], handler: Response[IO] => IO[A])(implicit client: Client[IO]): IO[A] = { + protected def executeRequest[A](request: IO[Request[IO]], handler: Response[IO] => IO[A])(implicit client: Client[IO]): IO[A] = { request.flatMap(client.run(_).use[IO, A](handler)) } @@ -188,7 +188,7 @@ abstract class DockerRegistryV2Abstract(override val config: DockerRegistryConfi /** * Builds the token request */ - private def buildTokenRequest(dockerInfoContext: DockerInfoContext): IO[Request[IO]] = { + protected def buildTokenRequest(dockerInfoContext: DockerInfoContext): IO[Request[IO]] = { val request = Method.GET( buildTokenRequestUri(dockerInfoContext.dockerImageID), buildTokenRequestHeaders(dockerInfoContext): _* @@ -220,7 +220,7 @@ abstract class DockerRegistryV2Abstract(override val config: DockerRegistryConfi * Request to get the manifest, using the auth token if provided */ private def manifestRequest(token: Option[String], imageId: DockerImageIdentifier, manifestHeader: Accept): IO[Request[IO]] = { - val authorizationHeader = token.map(t => Authorization(Credentials.Token(AuthScheme.Bearer, t))) + val authorizationHeader: Option[Authorization] = token.map(t => Authorization(Credentials.Token(AuthScheme.Bearer, t))) val request = Method.GET( buildManifestUri(imageId), List( diff --git a/dockerHashing/src/main/scala/cromwell/docker/registryv2/flows/azure/AcrAccessToken.scala b/dockerHashing/src/main/scala/cromwell/docker/registryv2/flows/azure/AcrAccessToken.scala new file mode 100644 index 00000000000..bf0841e2547 --- /dev/null +++ b/dockerHashing/src/main/scala/cromwell/docker/registryv2/flows/azure/AcrAccessToken.scala @@ -0,0 +1,3 @@ +package cromwell.docker.registryv2.flows.azure + +case class AcrAccessToken(access_token: String) diff --git a/dockerHashing/src/main/scala/cromwell/docker/registryv2/flows/azure/AcrRefreshToken.scala b/dockerHashing/src/main/scala/cromwell/docker/registryv2/flows/azure/AcrRefreshToken.scala new file mode 100644 index 00000000000..aa6a6d17eb5 --- /dev/null +++ b/dockerHashing/src/main/scala/cromwell/docker/registryv2/flows/azure/AcrRefreshToken.scala @@ -0,0 +1,3 @@ +package cromwell.docker.registryv2.flows.azure + +case class AcrRefreshToken(refresh_token: String) diff --git a/dockerHashing/src/main/scala/cromwell/docker/registryv2/flows/azure/AzureContainerRegistry.scala b/dockerHashing/src/main/scala/cromwell/docker/registryv2/flows/azure/AzureContainerRegistry.scala new file mode 100644 index 00000000000..46dfd116bc6 --- /dev/null +++ b/dockerHashing/src/main/scala/cromwell/docker/registryv2/flows/azure/AzureContainerRegistry.scala @@ -0,0 +1,149 @@ +package cromwell.docker.registryv2.flows.azure + +import cats.data.Validated.{Invalid, Valid} +import cats.effect.IO +import com.typesafe.scalalogging.LazyLogging +import common.validation.ErrorOr.ErrorOr +import cromwell.cloudsupport.azure.AzureCredentials +import cromwell.docker.DockerInfoActor.DockerInfoContext +import cromwell.docker.{DockerImageIdentifier, DockerRegistryConfig} +import cromwell.docker.registryv2.DockerRegistryV2Abstract +import org.http4s.{Header, Request, Response, Status} +import cromwell.docker.registryv2.flows.azure.AzureContainerRegistry.domain +import org.http4s.circe.jsonOf +import org.http4s.client.Client +import io.circe.generic.auto._ +import org.http4s._ + + +class AzureContainerRegistry(config: DockerRegistryConfig) extends DockerRegistryV2Abstract(config) 
with LazyLogging { + + /** + * (e.g. registry-1.docker.io) + */ + override protected def registryHostName(dockerImageIdentifier: DockerImageIdentifier): String = + dockerImageIdentifier.host.getOrElse("") + + override def accepts(dockerImageIdentifier: DockerImageIdentifier): Boolean = + dockerImageIdentifier.hostAsString.contains(domain) + + override protected def authorizationServerHostName(dockerImageIdentifier: DockerImageIdentifier): String = + dockerImageIdentifier.host.getOrElse("") + + /** + * In Azure, the service name does not exist at the registry level; it varies per repository, e.g. `terrabatchdev.azurecr.io` + */ + override def serviceName: Option[String] = + throw new Exception("ACR service name is host of user-defined registry, must derive from `DockerImageIdentifier`") + + /** + * Builds the list of headers for the token request + */ + override protected def buildTokenRequestHeaders(dockerInfoContext: DockerInfoContext): List[Header] = { + List(contentTypeHeader) + } + + private val contentTypeHeader: Header = { + import org.http4s.headers.`Content-Type` + import org.http4s.MediaType + + `Content-Type`(MediaType.application.`x-www-form-urlencoded`) + } + + private def getRefreshToken(authServerHostname: String, defaultAccessToken: String): IO[Request[IO]] = { + import org.http4s.Uri.{Authority, Scheme} + import org.http4s.client.dsl.io._ + import org.http4s._ + + val uri = Uri.apply( + scheme = Option(Scheme.https), + authority = Option(Authority(host = Uri.RegName(authServerHostname))), + path = "/oauth2/exchange", + query = Query.empty + ) + + org.http4s.Method.POST( + UrlForm( + "service" -> authServerHostname, + "access_token" -> defaultAccessToken, + "grant_type" -> "access_token" + ), + uri, + List(contentTypeHeader): _* + ) + } + + /* + Unlike other repositories, Azure reserves `GET /oauth2/token` for Basic Authentication [0]. + In order to use OAuth we must `POST /oauth2/token` [1]. + + [0] https://github.com/Azure/acr/blob/main/docs/Token-BasicAuth.md#using-the-token-api + [1] https://github.com/Azure/acr/blob/main/docs/AAD-OAuth.md#calling-post-oauth2token-to-get-an-acr-access-token + */ + private def getDockerAccessToken(hostname: String, repository: String, refreshToken: String): IO[Request[IO]] = { + import org.http4s.Uri.{Authority, Scheme} + import org.http4s.client.dsl.io._ + import org.http4s._ + + val uri = Uri.apply( + scheme = Option(Scheme.https), + authority = Option(Authority(host = Uri.RegName(hostname))), + path = "/oauth2/token", + query = Query.empty + ) + + org.http4s.Method.POST( + UrlForm( + // Tricky behavior - invalid `repository` values return a 200 with a valid-looking token. + // However, the token will cause 401s on all subsequent requests. 
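+        // e.g. for the test image terrabatchdev.azurecr.io/postgres used in the specs below,
+        // `repository` is "postgres", so the scope requested here is "repository:postgres:pull".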
+ "scope" -> s"repository:$repository:pull", + "service" -> hostname, + "refresh_token" -> refreshToken, + "grant_type" -> "refresh_token" + ), + uri, + List(contentTypeHeader): _* + ) + } + + override protected def getToken(dockerInfoContext: DockerInfoContext)(implicit client: Client[IO]): IO[Option[String]] = { + val hostname = authorizationServerHostName(dockerInfoContext.dockerImageID) + val maybeAadAccessToken: ErrorOr[String] = AzureCredentials.getAccessToken(None) // AAD token suitable for get-refresh-token request + val repository = dockerInfoContext.dockerImageID.image // ACR uses what we think of image name, as the repository + + // Top-level flow: AAD access token -> refresh token -> ACR access token + maybeAadAccessToken match { + case Valid(accessToken) => + (for { + refreshToken <- executeRequest(getRefreshToken(hostname, accessToken), parseRefreshToken) + dockerToken <- executeRequest(getDockerAccessToken(hostname, repository, refreshToken), parseAccessToken) + } yield dockerToken).map(Option.apply) + case Invalid(errors) => + IO.raiseError( + new Exception(s"Could not obtain AAD token to exchange for ACR refresh token. Error(s): ${errors}") + ) + } + } + + implicit val refreshTokenDecoder: EntityDecoder[IO, AcrRefreshToken] = jsonOf[IO, AcrRefreshToken] + implicit val accessTokenDecoder: EntityDecoder[IO, AcrAccessToken] = jsonOf[IO, AcrAccessToken] + + private def parseRefreshToken(response: Response[IO]): IO[String] = response match { + case Status.Successful(r) => r.as[AcrRefreshToken].map(_.refresh_token) + case r => + r.as[String].flatMap(b => IO.raiseError(new Exception(s"Request failed with status ${r.status.code} and body $b"))) + } + + private def parseAccessToken(response: Response[IO]): IO[String] = response match { + case Status.Successful(r) => r.as[AcrAccessToken].map(_.access_token) + case r => + r.as[String].flatMap(b => IO.raiseError(new Exception(s"Request failed with status ${r.status.code} and body $b"))) + } + +} + +object AzureContainerRegistry { + + def domain: String = "azurecr.io" + +} diff --git a/dockerHashing/src/test/scala/cromwell/docker/DockerImageIdentifierSpec.scala b/dockerHashing/src/test/scala/cromwell/docker/DockerImageIdentifierSpec.scala index 00c738dbede..41353934fc6 100644 --- a/dockerHashing/src/test/scala/cromwell/docker/DockerImageIdentifierSpec.scala +++ b/dockerHashing/src/test/scala/cromwell/docker/DockerImageIdentifierSpec.scala @@ -18,6 +18,7 @@ class DockerImageIdentifierSpec extends AnyFlatSpec with CromwellTimeoutSpec wit ("broad/cromwell/submarine", None, Option("broad/cromwell"), "submarine", "latest"), ("gcr.io/google/slim", Option("gcr.io"), Option("google"), "slim", "latest"), ("us-central1-docker.pkg.dev/google/slim", Option("us-central1-docker.pkg.dev"), Option("google"), "slim", "latest"), + ("terrabatchdev.azurecr.io/postgres", Option("terrabatchdev.azurecr.io"), None, "postgres", "latest"), // With tags ("ubuntu:latest", None, None, "ubuntu", "latest"), ("ubuntu:1235-SNAP", None, None, "ubuntu", "1235-SNAP"), @@ -25,6 +26,7 @@ class DockerImageIdentifierSpec extends AnyFlatSpec with CromwellTimeoutSpec wit ("index.docker.io:9999/ubuntu:170904", Option("index.docker.io:9999"), None, "ubuntu", "170904"), ("localhost:5000/capture/transwf:170904", Option("localhost:5000"), Option("capture"), "transwf", "170904"), ("quay.io/biocontainers/platypus-variant:0.8.1.1--htslib1.5_0", Option("quay.io"), Option("biocontainers"), "platypus-variant", "0.8.1.1--htslib1.5_0"), + ("terrabatchdev.azurecr.io/postgres:latest", 
Option("terrabatchdev.azurecr.io"), None, "postgres", "latest"), // Very long tags with trailing spaces cause problems for the re engine ("someuser/someimage:supercalifragilisticexpialidociouseventhoughthesoundofitissomethingquiteatrociousifyousayitloudenoughyoullalwayssoundprecocious ", None, Some("someuser"), "someimage", "supercalifragilisticexpialidociouseventhoughthesoundofitissomethingquiteatrociousifyousayitloudenoughyoullalwayssoundprecocious") ) diff --git a/dockerHashing/src/test/scala/cromwell/docker/DockerInfoActorSpec.scala b/dockerHashing/src/test/scala/cromwell/docker/DockerInfoActorSpec.scala index e41be33f762..72baec70825 100644 --- a/dockerHashing/src/test/scala/cromwell/docker/DockerInfoActorSpec.scala +++ b/dockerHashing/src/test/scala/cromwell/docker/DockerInfoActorSpec.scala @@ -2,6 +2,7 @@ package cromwell.docker import cromwell.core.Tags.IntegrationTest import cromwell.docker.DockerInfoActor._ +import cromwell.docker.registryv2.flows.azure.AzureContainerRegistry import cromwell.docker.registryv2.flows.dockerhub.DockerHubRegistry import cromwell.docker.registryv2.flows.google.GoogleRegistry import cromwell.docker.registryv2.flows.quay.QuayRegistry @@ -18,7 +19,8 @@ class DockerInfoActorSpec extends DockerRegistrySpec with AnyFlatSpecLike with M override protected lazy val registryFlows = List( new DockerHubRegistry(DockerRegistryConfig.default), new GoogleRegistry(DockerRegistryConfig.default), - new QuayRegistry(DockerRegistryConfig.default) + new QuayRegistry(DockerRegistryConfig.default), + new AzureContainerRegistry(DockerRegistryConfig.default) ) it should "retrieve a public docker hash" taggedAs IntegrationTest in { @@ -50,6 +52,16 @@ class DockerInfoActorSpec extends DockerRegistrySpec with AnyFlatSpecLike with M hash should not be empty } } + + it should "retrieve a private docker hash on acr" taggedAs IntegrationTest in { + dockerActor ! 
makeRequest("terrabatchdev.azurecr.io/postgres:latest") + + expectMsgPF(15 second) { + case DockerInfoSuccessResponse(DockerInformation(DockerHashResult(alg, hash), _), _) => + alg shouldBe "sha256" + hash should not be empty + } + } it should "send image not found message back if the image does not exist" taggedAs IntegrationTest in { val notFound = makeRequest("ubuntu:nonexistingtag") diff --git a/filesystems/blob/src/main/scala/cromwell/filesystems/blob/BlobFileSystemManager.scala b/filesystems/blob/src/main/scala/cromwell/filesystems/blob/BlobFileSystemManager.scala index e50446ea294..ed22d1b55f5 100644 --- a/filesystems/blob/src/main/scala/cromwell/filesystems/blob/BlobFileSystemManager.scala +++ b/filesystems/blob/src/main/scala/cromwell/filesystems/blob/BlobFileSystemManager.scala @@ -6,7 +6,7 @@ import com.azure.storage.blob.sas.{BlobContainerSasPermission, BlobServiceSasSig import com.typesafe.config.Config import com.typesafe.scalalogging.LazyLogging import common.validation.Validation._ -import cromwell.cloudsupport.azure.AzureUtils +import cromwell.cloudsupport.azure.{AzureCredentials, AzureUtils} import java.net.URI import java.nio.file.{FileSystem, FileSystemNotFoundException, FileSystems} From 9a140549636114431de8261d8f3169a27a16968b Mon Sep 17 00:00:00 2001 From: dspeck1 Date: Wed, 9 Aug 2023 17:47:45 -0500 Subject: [PATCH 02/25] WX-1179 Enable GCP Batch Integration Tests (#7199) Co-authored-by: Adam Nichols Co-authored-by: Adam Nichols --- .github/workflows/integration_tests.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/integration_tests.yml b/.github/workflows/integration_tests.yml index 98185b3d1b3..ebafe51064c 100644 --- a/.github/workflows/integration_tests.yml +++ b/.github/workflows/integration_tests.yml @@ -26,9 +26,9 @@ jobs: matrix: # Batch test fixes to land later include: - # - build_type: centaurGcpBatch - # build_mysql: 5.7 - # friendly_name: Centaur GCP Batch with MySQL 5.7 + - build_type: centaurGcpBatch + build_mysql: 5.7 + friendly_name: Centaur GCP Batch with MySQL 5.7 - build_type: centaurPapiV2beta build_mysql: 5.7 friendly_name: Centaur Papi V2 Beta with MySQL 5.7 From acf584099b68c35083bffa1815404b36bc7c363e Mon Sep 17 00:00:00 2001 From: dspeck1 Date: Fri, 11 Aug 2023 13:14:11 -0500 Subject: [PATCH 03/25] WX-1179 GCP Batch Docs Update (#7196) Co-authored-by: Jacob Jennings <53196131+jakejennings@users.noreply.github.com> Co-authored-by: Adam Nichols Co-authored-by: Adam Nichols --- docs/backends/GCPBatch.md | 68 +++++++-------------------------------- 1 file changed, 11 insertions(+), 57 deletions(-) diff --git a/docs/backends/GCPBatch.md b/docs/backends/GCPBatch.md index 766719af8c5..d626356223f 100644 --- a/docs/backends/GCPBatch.md +++ b/docs/backends/GCPBatch.md @@ -3,9 +3,9 @@ [//]: Google Cloud Batch is a fully managed service that lets you schedule, queue, and execute batch processing workloads on Google Cloud resources. Batch provisions resources and manages capacity on your behalf, allowing your batch workloads to run at scale. -This section offers detailed configuration instructions for using Cromwell with the Batch API in all supported +This section offers detailed configuration instructions for using Cromwell with the Google Cloud Batch in all supported authentication modes. 
Before reading further in this section please see the -[Getting started on Google Batch API](../tutorials/Batch101) for instructions common to all authentication modes +[Getting started on Google Cloud Batch](../tutorials/Batch101) for instructions common to all authentication modes and detailed instructions for the application default authentication scheme in particular. The instructions below assume you have created a Google Cloud Storage bucket and a Google project enabled for the appropriate APIs. @@ -90,7 +90,6 @@ While technically not part of Service Account authentication mode, one can also A [JSON key file for the service account](../wf_options/Google.md) must be passed in via the `user_service_account_json` field in the [Workflow Options](../wf_options/Google.md) when submitting the job. Omitting this field will cause the workflow to fail. The JSON should be passed as a string and will need to have no newlines and all instances of `"` and `\n` escaped. -[//]: # (TODO: is jes_gcs_root the correct workflow option?) In the likely event that this service account does not have access to Cromwell's default google project the `google_project` workflow option must be set. In the similarly likely case that this service account can not access Cromwell's default google bucket, the `jes_gcs_root` workflow option should be set appropriately. For information on the interaction of `user_service_account_json` with private Docker images please see the `Docker` section below. @@ -113,13 +112,11 @@ task mytask { } ``` -In order for a private image to be used the appropriate Docker configuration must be provided. If the Docker images being used +In order for a private image to be used, Docker Hub credentials must be provided. If the Docker images being used are public there is no need to add this configuration. For Batch -[//]: # (TODO: Is this the correct way to configure Docker for batch?) -[//]: # (5-4-23: Leave alone for now) ``` backend { default = GCPBATCH @@ -129,8 +126,6 @@ backend { config { dockerhub { token = "base64-encoded-docker-hub-username:password" - key-name = "name/of/the/kms/key/used/for/encrypting/and/decrypting/the/docker/hub/token" - auth = "reference-to-the-auth-cromwell-should-use-for-kms-encryption" } } } @@ -140,42 +135,6 @@ backend { `token` is the standard base64-encoded username:password for the appropriate Docker Hub account. -`key-name` is the name of the Google KMS key Cromwell should use for encrypting the Docker `token` before including it -in the PAPI job execution request. This `key-name` will also be included in the PAPI job execution -request and will be used by Batch to decrypt the Docker token used by `docker login` to enable access to the private Docker image. - -`auth` is a reference to the name of an authorization in the `auths` block of Cromwell's `google` config. -Cromwell will use this authorization for encrypting the Google KMS key. - -The equivalents of `key-name`, `token` and `auth` can also be specified in workflow options which take -precedence over values specified in configuration. The corresponding workflow options are named `docker_credentials_key_name`, -`docker_credentials_token`, and `user_service_account_json`. While the config value `auth` refers to an auth defined in the -`google.auths` stanza elsewhere in Cromwell's -configuration, `user_service_account_json` is expected to be a literal escaped Google service account auth JSON. -See the `User Service Account` section above for more information on using user service accounts. 
-If the key, token or auth value is provided in workflow options then the corresponding private Docker configuration value -is not required, and vice versa. Also note that for the `user_service_account_json` workflow option to work an auth of type `user_service_account` -must be defined in Cromwell's `google.auths` stanza; more details in the `User Service Account` section above. - -Example Batch workflow options for private Docker configuration: - -``` -{ - "docker_credentials_key_name": "name/of/the/kms/key/used/for/encrypting/and/decrypting/the/docker/hub/token", - "docker_credentials_token": "base64_username:password", - "user_service_account_json": "" -} -``` - -Important - -If any of the three private Docker configuration values of key name, auth, or Docker token are missing, Batch will not perform a `docker login`. -If the Docker image to be pulled is not public the `docker pull` will fail which will cause the overall job to fail. - -If using any of these private Docker workflow options it is advisable to add -them to the `workflow-options.encrypted-fields` list in Cromwell configuration. - - **Monitoring** In order to monitor metrics (CPU, Memory, Disk usage...) about the VM during Call Runtime, a workflow option can be used to specify the path to a script that will run in the background and write its output to a log file. @@ -207,7 +166,7 @@ backend.providers.GCPBATCH.config { #### Google Labels -Every call run on the Batch API backend is given certain labels by default, so that Google resources can be queried by these labels later. +Every call run on the GCP Batch backend is given certain labels by default, so that Google resources can be queried by these labels later. The current default label set automatically applied is: | Key | Value | Example | Notes | @@ -217,7 +176,7 @@ The current default label set automatically applied is: | wdl-task-name | The name of the WDL task | my-task | | | wdl-call-alias | The alias of the WDL call that created this job | my-task-1 | Only present if the task was called with an alias. | -Any custom labels provided as '`google_labels`' in the [workflow options](../wf_options/Google) are also applied to Google resources by the Batch API. +Any custom labels provided as '`google_labels`' in the [workflow options](../wf_options/Google) are also applied to Google resources by GCP Batch. ### Virtual Private Network @@ -257,12 +216,12 @@ configuration key, which is `vpc-network` here, as the name of private network a If the network name is not present in the config Cromwell will fall back to trying to run jobs on the default network. If the `network-name` or `subnetwork-name` values contain the string `${projectId}` then that value will be replaced -by Cromwell with the name of the project running the Batch API. +by Cromwell with the name of the project running GCP Batch. If the `network-name` does not contain a `/` then it will be prefixed with `projects/${projectId}/global/networks/`. -Cromwell will then pass the network and subnetwork values to the Batch API. See the documentation for the -[Batch API](https://cloud.google.com/batch/docs/networking-overview) +Cromwell will then pass the network and subnetwork values to GCP Batch. See the documentation for +[GCP Batch](https://cloud.google.com/batch/docs/networking-overview) for more information on the various formats accepted for `network` and `subnetwork`. #### Virtual Private Network via Labels @@ -306,7 +265,6 @@ network labels, and then fall back to running on the default network. 
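As a concrete sketch of the network configuration described above (the network and subnetwork names are illustrative, and the `virtual-private-cloud` stanza name is assumed from Cromwell's other Google backends):

```
backend.providers.GCPBATCH.config {
  virtual-private-cloud {
    network-name = "vpc-network"        # a bare name like this is expanded to projects/${projectId}/global/networks/vpc-network
    subnetwork-name = "vpc-subnetwork"
  }
}
```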
### Custom Google Cloud SDK container -[//]: # (TODO: need to test this section as well) Cromwell can't use Google's container registry if a VPC perimeter is used in the project. Your own repository can be used instead by adding a `cloud-sdk-image-url` reference to the container to use: ``` google { cloud-sdk-image-url = "eu.gcr.io/your-project-id/cloudsdktool/cloud-sdk:354.0.0-alpine" } ``` ### Parallel Composite Uploads -[//]: # (TODO: Need to test parallel composite uploads) - Cromwell can be configured to use GCS parallel composite uploads which can greatly improve delocalization performance. This feature is turned off by default but can be enabled backend-wide by specifying a `gsutil`-compatible memory specification for the key `genomics.parallel-composite-upload-threshold` in backend configuration. This memory value represents the minimum size an output file must have to be a candidate for `gsutil` parallel composite uploading: outputs. Calls which are executed and not cached will always honor the parallel composite upload setting at the time of their execution. ### Migration from Google Cloud Life Sciences v2beta to Google Cloud Batch 1. If you currently run your workflows using Cloud Life Sciences v2beta and would like to switch to Google Cloud Batch, you will need to make a few changes to your configuration file: `actor-factory` value should be changed from `cromwell.backend.google.pipelines.v2beta.PipelinesApiLifecycleActorFactory` to `cromwell.backend.google.batch.GcpBatchLifecycleActorFactory`. 2. You will need to remove the parameter `genomics.endpoint-url` and generate a new config file. 3. Google Cloud Batch is now available in a variety of regions. Please see the [Batch Locations](https://cloud.google.com/batch/docs/locations) for a list of supported regions. ### Reference Disk Support - Cromwell 55 and later support mounting reference disks from prebuilt GCP disk images as an alternative to localizing large input reference files on Batch. Please note the configuration of reference disk manifests has changed starting with Cromwell 57 and now uses the format documented below. From 29d3810988ece6ccc85feed631bca6b881c7b7d5 Mon Sep 17 00:00:00 2001 From: Trevyn Langsford Date: Mon, 14 Aug 2023 11:41:27 -0400 Subject: [PATCH 04/25] ID-734 Increase Timeout for DRSHub Communication (#7198) --- .../scala/cloud/nio/impl/drs/DrsConfig.scala | 6 ++--- .../cloud/nio/impl/drs/DrsPathResolver.scala | 24 +++++++++++++------ .../nio/impl/drs/EngineDrsPathResolver.scala | 2 +- docs/filesystems/Filesystems.md | 8 +++---- 4 files changed, 25 insertions(+), 15 deletions(-) diff --git a/cloud-nio/cloud-nio-impl-drs/src/main/scala/cloud/nio/impl/drs/DrsConfig.scala b/cloud-nio/cloud-nio-impl-drs/src/main/scala/cloud/nio/impl/drs/DrsConfig.scala index c8333a57a66..a2b0a385680 100644 --- a/cloud-nio/cloud-nio-impl-drs/src/main/scala/cloud/nio/impl/drs/DrsConfig.scala +++ b/cloud-nio/cloud-nio-impl-drs/src/main/scala/cloud/nio/impl/drs/DrsConfig.scala @@ -17,9 +17,9 @@ final case class DrsConfig(drsResolverUrl: String, object DrsConfig { // If you update these values also update Filesystems.md! 
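  // A sketch of the resulting schedule, assuming standard exponential backoff semantics for these values:
  // waits of roughly 30s, then 37.5s (30 x 1.25), then 46.9s, never exceeding the 60s maximum,
  // with each wait jittered by about +/-10% via the randomization factor.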
private val DefaultNumRetries = 3 - private val DefaultWaitInitial = 10 seconds - private val DefaultWaitMaximum = 30 seconds - private val DefaultWaitMultiplier = 1.5d + private val DefaultWaitInitial = 30 seconds + private val DefaultWaitMaximum = 60 seconds + private val DefaultWaitMultiplier = 1.25d private val DefaultWaitRandomizationFactor = 0.1 private val EnvDrsResolverUrl = "DRS_RESOLVER_URL" diff --git a/cloud-nio/cloud-nio-impl-drs/src/main/scala/cloud/nio/impl/drs/DrsPathResolver.scala b/cloud-nio/cloud-nio-impl-drs/src/main/scala/cloud/nio/impl/drs/DrsPathResolver.scala index f9ae5b62e03..22d86c31726 100644 --- a/cloud-nio/cloud-nio-impl-drs/src/main/scala/cloud/nio/impl/drs/DrsPathResolver.scala +++ b/cloud-nio/cloud-nio-impl-drs/src/main/scala/cloud/nio/impl/drs/DrsPathResolver.scala @@ -17,6 +17,7 @@ import org.apache.commons.lang3.exception.ExceptionUtils import org.apache.http.client.methods.{HttpGet, HttpPost} import org.apache.http.entity.{ContentType, StringEntity} import org.apache.http.impl.client.HttpClientBuilder +import org.apache.http.impl.conn.PoolingHttpClientConnectionManager import org.apache.http.util.EntityUtils import org.apache.http.{HttpResponse, HttpStatus, StatusLine} @@ -24,16 +25,16 @@ import java.nio.ByteBuffer import java.nio.channels.{Channels, ReadableByteChannel} import scala.util.Try -abstract class DrsPathResolver(drsConfig: DrsConfig, retryInternally: Boolean = true) { +abstract class DrsPathResolver(drsConfig: DrsConfig) { protected lazy val httpClientBuilder: HttpClientBuilder = { val clientBuilder = HttpClientBuilder.create() - if (retryInternally) { - val retryHandler = new DrsResolverHttpRequestRetryStrategy(drsConfig) - clientBuilder - .setRetryHandler(retryHandler) - .setServiceUnavailableRetryStrategy(retryHandler) - } + val retryHandler = new DrsResolverHttpRequestRetryStrategy(drsConfig) + clientBuilder + .setRetryHandler(retryHandler) + .setServiceUnavailableRetryStrategy(retryHandler) + clientBuilder.setConnectionManager(connectionManager) + clientBuilder.setConnectionManagerShared(true) clientBuilder } @@ -241,4 +242,13 @@ object DrsResolverResponseSupport { baseMessage + "(empty response)" } } + + lazy val connectionManager = { + val connManager = new PoolingHttpClientConnectionManager() + connManager.setMaxTotal(250) + // Since the HttpClient is always talking to DRSHub, + // make the max connections per route the same as max total connections + connManager.setDefaultMaxPerRoute(250) + connManager + } } diff --git a/cloud-nio/cloud-nio-impl-drs/src/main/scala/cloud/nio/impl/drs/EngineDrsPathResolver.scala b/cloud-nio/cloud-nio-impl-drs/src/main/scala/cloud/nio/impl/drs/EngineDrsPathResolver.scala index a62ce7971c2..01f7a488eb3 100644 --- a/cloud-nio/cloud-nio-impl-drs/src/main/scala/cloud/nio/impl/drs/EngineDrsPathResolver.scala +++ b/cloud-nio/cloud-nio-impl-drs/src/main/scala/cloud/nio/impl/drs/EngineDrsPathResolver.scala @@ -5,7 +5,7 @@ import common.validation.ErrorOr.ErrorOr case class EngineDrsPathResolver(drsConfig: DrsConfig, drsCredentials: DrsCredentials, ) - extends DrsPathResolver(drsConfig, retryInternally = false) { + extends DrsPathResolver(drsConfig) { override def getAccessToken: ErrorOr[String] = drsCredentials.getAccessToken } diff --git a/docs/filesystems/Filesystems.md b/docs/filesystems/Filesystems.md index 0630e421c0b..8ca03825c6e 100644 --- a/docs/filesystems/Filesystems.md +++ b/docs/filesystems/Filesystems.md @@ -23,12 +23,12 @@ filesystems { # The number of times to retry failures connecting or HTTP 
429 or HTTP 5XX responses, default 3. num-retries = 3 - # How long to wait between retrying HTTP 429 or HTTP 5XX responses, default 10 seconds. - wait-initial = 10 seconds + # How long to wait between retrying HTTP 429 or HTTP 5XX responses, default 30 seconds. + wait-initial = 30 seconds - # The maximum amount of time to wait between retrying HTTP 429 or HTTP 5XX responses, default 30 seconds. - wait-maximum = 30 seconds + # The maximum amount of time to wait between retrying HTTP 429 or HTTP 5XX responses, default 60 seconds. + wait-maximum = 60 seconds # The amount to multiply the amount of time to wait between retrying HTTP 429 or HTTP 5XX responses. - # Default 2.0, and will never multiply the wait time more than wait-maximum. - wait-mulitiplier = 2.0 + # Default 1.25, and will never multiply the wait time more than wait-maximum. + wait-mulitiplier = 1.25 # The randomization factor to use for creating a range around the wait interval. # A randomization factor of 0.5 results in a random period ranging between 50% below and 50% above the wait # interval. Default 0.1. From 2a69691ec56ba0e8f279b8f006ba796bb9cfaf05 Mon Sep 17 00:00:00 2001 From: Tom Wiseman Date: Tue, 15 Aug 2023 15:01:14 -0400 Subject: [PATCH 05/25] [WX-1156] internal_path_prefix for TES 4.4 (#7190) --- .../backend/impl/tes/TesJobPaths.scala | 9 +++ .../cromwell/backend/impl/tes/TesTask.scala | 34 +++++----- .../backend/impl/tes/TesTaskSpec.scala | 63 +++++++++++++++---- 3 files changed, 81 insertions(+), 25 deletions(-) diff --git a/supportedBackends/tes/src/main/scala/cromwell/backend/impl/tes/TesJobPaths.scala b/supportedBackends/tes/src/main/scala/cromwell/backend/impl/tes/TesJobPaths.scala index 0b75cfc0f3a..f1797d9bc58 100644 --- a/supportedBackends/tes/src/main/scala/cromwell/backend/impl/tes/TesJobPaths.scala +++ b/supportedBackends/tes/src/main/scala/cromwell/backend/impl/tes/TesJobPaths.scala @@ -28,6 +28,15 @@ case class TesJobPaths private[tes] (override val workflowPaths: TesWorkflowPath val callInputsDockerRoot = callDockerRoot.resolve("inputs") val callInputsRoot = callRoot.resolve("inputs") + /* + * tesTaskRoot: This is the root directory that TES will use for files related to this task. + * We provide it to TES as a k/v pair where the key is "internal_path_prefix" (specified in TesWorkflowOptionKeys.scala) + * and the value is a blob path. + * This is not a standard TES feature, but rather related to the Azure TES implementation that Terra uses. + * While passing it outside of Terra won't do any harm, we could consider making this optional and/or configurable. 
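+   * For example (storage account, container, and ids below are hypothetical), the pair might look like:
+   *   "internal_path_prefix" -> "https://myaccount.blob.core.windows.net/my-container/my-workflow/1234/call-foo/execution/tes_task"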
+ */ + val tesTaskRoot : Path = callExecutionRoot.resolve("tes_task") + // Given an output path, return a path localized to the storage file system def storageOutput(path: String): String = { callExecutionRoot.resolve(path).toString diff --git a/supportedBackends/tes/src/main/scala/cromwell/backend/impl/tes/TesTask.scala b/supportedBackends/tes/src/main/scala/cromwell/backend/impl/tes/TesTask.scala index f7674c377e7..66f0f508d77 100644 --- a/supportedBackends/tes/src/main/scala/cromwell/backend/impl/tes/TesTask.scala +++ b/supportedBackends/tes/src/main/scala/cromwell/backend/impl/tes/TesTask.scala @@ -1,5 +1,4 @@ package cromwell.backend.impl.tes - import common.collections.EnhancedCollections._ import common.util.StringUtil._ import cromwell.backend.impl.tes.OutputMode.OutputMode @@ -71,7 +70,6 @@ final case class TesTask(jobDescriptor: BackendJobDescriptor, path = tesPaths.callExecutionDockerRoot.resolve("script").toString, `type` = Option("FILE") ) - private def writeFunctionFiles: Map[FullyQualifiedName, Seq[WomFile]] = instantiatedCommand.createdFiles map { f => f.file.value.md5SumShort -> List(f.file) } toMap @@ -231,11 +229,6 @@ final case class TesTask(jobDescriptor: BackendJobDescriptor, workflowExecutionIdentityOption ) - val resources: Resources = TesTask.makeResources( - runtimeAttributes, - preferedWorkflowExecutionIdentity - ) - val executors = Seq(Executor( image = dockerImageUsed, command = Seq(jobShell, commandScript.path), @@ -245,6 +238,12 @@ final case class TesTask(jobDescriptor: BackendJobDescriptor, stdin = None, env = None )) + + val resources: Resources = TesTask.makeResources( + runtimeAttributes, + preferedWorkflowExecutionIdentity, + Option(tesPaths.tesTaskRoot.pathAsString) + ) } object TesTask { @@ -254,15 +253,22 @@ object TesTask { configIdentity.map(_.value).orElse(workflowOptionsIdentity.map(_.value)) } def makeResources(runtimeAttributes: TesRuntimeAttributes, - workflowExecutionId: Option[String]): Resources = { - - // This was added in BT-409 to let us pass information to an Azure - // TES server about which user identity to run tasks as. - // Note that we validate the type of WorkflowExecutionIdentity - // in TesInitializationActor. - val backendParameters = runtimeAttributes.backendParameters ++ + workflowExecutionId: Option[String], internalPathPrefix: Option[String]): Resources = { + /* + * workflowExecutionId: This was added in BT-409 to let us pass information to an Azure + * TES server about which user identity to run tasks as. + * Note that we validate the type of WorkflowExecutionIdentity in TesInitializationActor. + * + * internalPathPrefix: Added in WX-1156 to support the azure TES implementation. Specifies + * a working directory that the TES task can use. 
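+     * Both values land in the returned Resources' backendParameters map; TesTaskSpec below shows the
+     * resulting shape, e.g. Map(TesWorkflowOptionKeys.WorkflowExecutionIdentity -> Option("abc123"), "internal_path_prefix" -> internalPathPrefix).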
+ */ + val internalPathPrefixKey = "internal_path_prefix" + val backendParameters : Map[String, Option[String]] = runtimeAttributes.backendParameters ++ workflowExecutionId .map(TesWorkflowOptionKeys.WorkflowExecutionIdentity -> Option(_)) + .toMap ++ + internalPathPrefix + .map(internalPathPrefixKey -> Option(_)) .toMap val disk :: ram :: _ = Seq(runtimeAttributes.disk, runtimeAttributes.memory) map { case Some(x) => diff --git a/supportedBackends/tes/src/test/scala/cromwell/backend/impl/tes/TesTaskSpec.scala b/supportedBackends/tes/src/test/scala/cromwell/backend/impl/tes/TesTaskSpec.scala index e2c743bb718..5bfa916086d 100644 --- a/supportedBackends/tes/src/test/scala/cromwell/backend/impl/tes/TesTaskSpec.scala +++ b/supportedBackends/tes/src/test/scala/cromwell/backend/impl/tes/TesTaskSpec.scala @@ -31,34 +31,40 @@ class TesTaskSpec false, Map.empty ) + val internalPathPrefix = Option("mock/path/to/tes/task") + val expectedTuple = "internal_path_prefix" -> internalPathPrefix it should "create the correct resources when an identity is passed in WorkflowOptions" in { val wei = Option("abc123") - TesTask.makeResources(runtimeAttributes, wei) shouldEqual - Resources(None, None, None, Option(false), None, Option(Map(TesWorkflowOptionKeys.WorkflowExecutionIdentity -> Option("abc123"))) + TesTask.makeResources(runtimeAttributes, wei, internalPathPrefix) shouldEqual + Resources(None, None, None, Option(false), None, + Option(Map(TesWorkflowOptionKeys.WorkflowExecutionIdentity -> Option("abc123"), + expectedTuple)) ) } it should "create the correct resources when an empty identity is passed in WorkflowOptions" in { val wei = Option("") - TesTask.makeResources(runtimeAttributes, wei) shouldEqual - Resources(None, None, None, Option(false), None, Option(Map(TesWorkflowOptionKeys.WorkflowExecutionIdentity -> Option(""))) + TesTask.makeResources(runtimeAttributes, wei, internalPathPrefix) shouldEqual + Resources(None, None, None, Option(false), None, + Option(Map(TesWorkflowOptionKeys.WorkflowExecutionIdentity -> Option(""), + expectedTuple)) ) } it should "create the correct resources when no identity is passed in WorkflowOptions" in { val wei = None - TesTask.makeResources(runtimeAttributes, wei) shouldEqual - Resources(None, None, None, Option(false), None, Option(Map.empty[String, Option[String]]) - ) + TesTask.makeResources(runtimeAttributes, wei, internalPathPrefix) shouldEqual + Resources(None, None, None, Option(false), None, Option(Map(expectedTuple))) } it should "create the correct resources when an identity is passed in via backend config" in { val weic = Option(WorkflowExecutionIdentityConfig("abc123")) val weio = Option(WorkflowExecutionIdentityOption("def456")) val wei = TesTask.getPreferredWorkflowExecutionIdentity(weic, weio) - TesTask.makeResources(runtimeAttributes, wei) shouldEqual - Resources(None, None, None, Option(false), None, Option(Map(TesWorkflowOptionKeys.WorkflowExecutionIdentity -> Option("abc123"))) + TesTask.makeResources(runtimeAttributes, wei, internalPathPrefix) shouldEqual + Resources(None, None, None, Option(false), None, Option(Map(TesWorkflowOptionKeys.WorkflowExecutionIdentity -> Option("abc123"), + expectedTuple)) ) } @@ -66,11 +72,46 @@ class TesTaskSpec val weic = None val weio = Option(WorkflowExecutionIdentityOption("def456")) val wei = TesTask.getPreferredWorkflowExecutionIdentity(weic, weio) - TesTask.makeResources(runtimeAttributes, wei) shouldEqual - Resources(None, None, None, Option(false), None, 
Option(Map(TesWorkflowOptionKeys.WorkflowExecutionIdentity -> Option("def456"))) + TesTask.makeResources(runtimeAttributes, wei, internalPathPrefix) shouldEqual + Resources(None, None, None, Option(false), None, Option(Map(TesWorkflowOptionKeys.WorkflowExecutionIdentity -> Option("def456"), + expectedTuple)) ) } + it should "correctly set the internal path prefix when provided as a backend parameter" in { + val wei = Option("abc123") + val internalPathPrefix = Option("mock/path/to/tes/task") + TesTask.makeResources(runtimeAttributes, wei, internalPathPrefix) shouldEqual + Resources(None, None, None, Option(false), None, + Option(Map(TesWorkflowOptionKeys.WorkflowExecutionIdentity -> Option("abc123"), + "internal_path_prefix" -> internalPathPrefix) + )) + } + + it should "correctly resolve the path to .../tes_task and add the k/v pair to backend parameters" in { + val emptyWorkflowOptions = WorkflowOptions(JsObject(Map.empty[String, JsValue])) + val workflowDescriptor = buildWdlWorkflowDescriptor(TestWorkflows.HelloWorld, + labels = Labels("foo" -> "bar")) + val jobDescriptor = jobDescriptorFromSingleCallWorkflow(workflowDescriptor, + Map.empty, + emptyWorkflowOptions, + Set.empty) + val tesPaths = TesJobPaths(jobDescriptor.key, + jobDescriptor.workflowDescriptor, + TestConfig.emptyConfig) + + val expectedKey = "internal_path_prefix" + val expectedValue = Option(tesPaths.tesTaskRoot.pathAsString) + + //Assert path correctly ends up in the resources + val wei = Option("abc123") + TesTask.makeResources(runtimeAttributes, wei, expectedValue) shouldEqual + Resources(None, None, None, Option(false), None, + Option(Map(TesWorkflowOptionKeys.WorkflowExecutionIdentity -> Option("abc123"), + expectedKey -> expectedValue)) + ) + } + it should "copy labels to tags" in { val jobLogger = mock[JobLogger] val emptyWorkflowOptions = WorkflowOptions(JsObject(Map.empty[String, JsValue])) From 2fc467e2be4ba7930a386ea43f12b06432d60c29 Mon Sep 17 00:00:00 2001 From: Dillon Scott <67511512+dillydally414@users.noreply.github.com> Date: Thu, 17 Aug 2023 15:51:28 -0400 Subject: [PATCH 06/25] [WM-2184] Remove brackets from Jira ID (#7206) --- .github/workflows/chart_update_on_merge.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/chart_update_on_merge.yml b/.github/workflows/chart_update_on_merge.yml index 00995f5178b..a2b14f2ec65 100644 --- a/.github/workflows/chart_update_on_merge.yml +++ b/.github/workflows/chart_update_on_merge.yml @@ -14,7 +14,7 @@ jobs: - name: Fetch Jira ID from the commit message id: fetch-jira-id run: | - JIRA_ID=$(echo '${{ github.event.pull_request.title }}' | grep -Eo '\[?[A-Z][A-Z]+-[0-9]+\]?') + JIRA_ID=$(echo '${{ github.event.pull_request.title }}' | grep -Eo '[A-Z][A-Z]+-[0-9]+' | xargs echo -n | tr '[:space:]' ',') [[ -z "$JIRA_ID" ]] && { echo "No Jira ID found in $1" ; exit 1; } echo "JIRA_ID=$JIRA_ID" >> $GITHUB_OUTPUT - name: Clone Cromwell From c025756c349e45f1d7cd4939869106368d217d17 Mon Sep 17 00:00:00 2001 From: Adam Nichols Date: Thu, 17 Aug 2023 18:00:49 -0400 Subject: [PATCH 07/25] WX-1153 Azure blob read md5 from metadata for large files (#7204) Co-authored-by: Janet Gainer-Dewar --- .github/workflows/cromwell_unit_tests.yml | 4 + .../blob/BlobFileSystemManager.scala | 11 +- .../filesystems/blob/BlobPathBuilder.scala | 48 +++++++-- .../blob/BlobPathBuilderSpec.scala | 100 ++++++++++++------ 4 files changed, 117 insertions(+), 46 deletions(-) diff --git a/.github/workflows/cromwell_unit_tests.yml 
b/.github/workflows/cromwell_unit_tests.yml index 797f38efd96..88951871d8f 100644 --- a/.github/workflows/cromwell_unit_tests.yml +++ b/.github/workflows/cromwell_unit_tests.yml @@ -28,6 +28,10 @@ jobs: #Invoke SBT to run all unit tests for Cromwell. - name: Run tests + env: + AZURE_CLIENT_ID: ${{ secrets.VAULT_AZURE_CENTAUR_CLIENT_ID }} + AZURE_CLIENT_SECRET: ${{ secrets.VAULT_AZURE_CENTAUR_CLIENT_SECRET }} + AZURE_TENANT_ID: ${{ secrets.VAULT_AZURE_CENTAUR_TENANT_ID }} run: | set -e sbt "test" diff --git a/filesystems/blob/src/main/scala/cromwell/filesystems/blob/BlobFileSystemManager.scala b/filesystems/blob/src/main/scala/cromwell/filesystems/blob/BlobFileSystemManager.scala index ed22d1b55f5..ac8f01d2cc7 100644 --- a/filesystems/blob/src/main/scala/cromwell/filesystems/blob/BlobFileSystemManager.scala +++ b/filesystems/blob/src/main/scala/cromwell/filesystems/blob/BlobFileSystemManager.scala @@ -1,7 +1,7 @@ package cromwell.filesystems.blob import com.azure.core.credential.AzureSasCredential -import com.azure.storage.blob.nio.AzureFileSystem +import com.azure.storage.blob.nio.{AzureFileSystem, AzureFileSystemProvider} import com.azure.storage.blob.sas.{BlobContainerSasPermission, BlobServiceSasSignatureValues} import com.typesafe.config.Config import com.typesafe.scalalogging.LazyLogging import common.validation.Validation._ import cromwell.cloudsupport.azure.{AzureCredentials, AzureUtils} import java.net.URI -import java.nio.file.{FileSystem, FileSystemNotFoundException, FileSystems} +import java.nio.file.spi.FileSystemProvider +import java.nio.file.{FileSystem, FileSystemNotFoundException} import java.time.temporal.ChronoUnit import java.time.{Duration, Instant, OffsetDateTime} import scala.jdk.CollectionConverters._ import scala.util.{Failure, Success, Try} // We encapsulate this functionality here so that we can easily mock it out, to allow for testing without // actually connecting to Blob storage. -case class FileSystemAPI() { - def getFileSystem(uri: URI): Try[FileSystem] = Try(FileSystems.getFileSystem(uri)) - def newFileSystem(uri: URI, config: Map[String, Object]): FileSystem = FileSystems.newFileSystem(uri, config.asJava) +case class FileSystemAPI(private val provider: FileSystemProvider = new AzureFileSystemProvider()) { + def getFileSystem(uri: URI): Try[FileSystem] = Try(provider.getFileSystem(uri)) + def newFileSystem(uri: URI, config: Map[String, Object]): FileSystem = provider.newFileSystem(uri, config.asJava) def closeFileSystem(uri: URI): Option[Unit] = getFileSystem(uri).toOption.map(_.close) } /** diff --git a/filesystems/blob/src/main/scala/cromwell/filesystems/blob/BlobPathBuilder.scala b/filesystems/blob/src/main/scala/cromwell/filesystems/blob/BlobPathBuilder.scala index 9e7b230286c..aa6445fa779 100644 --- a/filesystems/blob/src/main/scala/cromwell/filesystems/blob/BlobPathBuilder.scala +++ b/filesystems/blob/src/main/scala/cromwell/filesystems/blob/BlobPathBuilder.scala @@ -7,6 +7,7 @@ import cromwell.filesystems.blob.BlobPathBuilder._ import java.net.{MalformedURLException, URI} import java.nio.file.{Files, LinkOption} +import scala.jdk.CollectionConverters._ import scala.language.postfixOps import scala.util.{Failure, Success, Try} @@ -78,6 +79,20 @@ object BlobPath { // format the library expects // 2) If the path looks like <container>:<path>, strip off the <container>: to leave the absolute path inside the container. 
private val brokenPathRegex = "https:/([a-z0-9]+).blob.core.windows.net/([-a-zA-Z0-9]+)/(.*)".r + + // Blob files larger than 5 GB upload in parallel parts [0][1] and do not get a native `CONTENT-MD5` property. + // Instead, some uploaders such as TES [2] calculate the md5 themselves and store it under this key in metadata. + // They do this for all files they touch, regardless of size, and the root/metadata property is authoritative over native. + // + // N.B. most if not virtually all large files in the wild will NOT have this key populated because they were not created + // by TES or its associated upload utility [4]. + // + // [0] https://learn.microsoft.com/en-us/azure/storage/blobs/scalability-targets + // [1] https://learn.microsoft.com/en-us/rest/api/storageservices/version-2019-12-12 + // [2] https://github.com/microsoft/ga4gh-tes/blob/03feb746bb961b72fa91266a56db845e3b31be27/src/Tes.Runner/Transfer/BlobBlockApiHttpUtils.cs#L25 + // [4] https://github.com/microsoft/ga4gh-tes/blob/main/src/Tes.RunnerCLI/scripts/roothash.sh + private val largeBlobFileMetadataKey = "md5_4mib_hashlist_root_hash" + def cleanedNioPathString(nioString: String): String = { val pathStr = nioString match { case brokenPathRegex(_, containerName, pathInContainer) => @@ -116,16 +131,33 @@ case class BlobPath private[blob](pathString: String, endpoint: EndpointURL, con def blobFileAttributes: Try[AzureBlobFileAttributes] = Try(Files.readAttributes(nioPath, classOf[AzureBlobFileAttributes])) + def blobFileMetadata: Try[Option[Map[String, String]]] = blobFileAttributes.map { attrs => + // `metadata()` has a documented `null` case + Option(attrs.metadata()).map(_.asScala.toMap) + } + def md5HexString: Try[Option[String]] = { - blobFileAttributes.map(h => - Option(h.blobHttpHeaders().getContentMd5) match { - case None => None - case Some(arr) if arr.isEmpty => None - // Convert the bytes to a hex-encoded string. Note that this value - // is rendered in base64 in the Azure web portal. - case Some(bytes) => Option(bytes.map("%02x".format(_)).mkString) + def md5FromMetadata: Option[String] = (blobFileMetadata map { maybeMetadataMap: Option[Map[String, String]] => + maybeMetadataMap flatMap { metadataMap: Map[String, String] => + metadataMap.get(BlobPath.largeBlobFileMetadataKey) } - ) + }).toOption.flatten + + // Convert the bytes to a hex-encoded string. Note that the value + // is rendered in base64 in the Azure web portal. + def hexString(bytes: Array[Byte]): String = bytes.map("%02x".format(_)).mkString + + blobFileAttributes.map { attr: AzureBlobFileAttributes => + (Option(attr.blobHttpHeaders().getContentMd5), md5FromMetadata) match { + case (None, None) => None + // (Some, Some) will happen for all <5 GB files uploaded by TES. Per Microsoft 2023-08-15 the + // root/metadata algorithm emits different values than the native algorithm and we should + // always choose metadata for consistency with larger files that only have that one. 
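+          // (The both-present case is exercised by "redundant_md5_test.txt" in BlobPathBuilderSpec below,
+          // which expects the metadata md5 "021c7cc715ec82292bb9b925f9ca44d3" to be chosen over the native one.)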
+ case (_, Some(metadataMd5)) => Option(metadataMd5) + case (Some(headerMd5Bytes), None) if headerMd5Bytes.isEmpty => None + case (Some(headerMd5Bytes), None) => Option(hexString(headerMd5Bytes)) + } + } } /** diff --git a/filesystems/blob/src/test/scala/cromwell/filesystems/blob/BlobPathBuilderSpec.scala b/filesystems/blob/src/test/scala/cromwell/filesystems/blob/BlobPathBuilderSpec.scala index 4012e241eb3..eef6db8e942 100644 --- a/filesystems/blob/src/test/scala/cromwell/filesystems/blob/BlobPathBuilderSpec.scala +++ b/filesystems/blob/src/test/scala/cromwell/filesystems/blob/BlobPathBuilderSpec.scala @@ -89,46 +89,80 @@ class BlobPathBuilderSpec extends AnyFlatSpec with Matchers with MockSugar { ) } - //// The below tests are IGNORED because they depend on Azure auth information being present in the environment //// + // The following tests use the `centaurtesting` account injected into CI. They depend on access to the + // container specified below. You may need to log in to az cli locally to get them to pass. private val subscriptionId: SubscriptionId = SubscriptionId(UUID.fromString("62b22893-6bc1-46d9-8a90-806bb3cce3c9")) - private val endpoint: EndpointURL = BlobPathBuilderSpec.buildEndpoint("coaexternalstorage") - private val store: BlobContainerName = BlobContainerName("inputs") + private val endpoint: EndpointURL = BlobPathBuilderSpec.buildEndpoint("centaurtesting") + private val container: BlobContainerName = BlobContainerName("test-blob") - def makeBlobPathBuilder(blobEndpoint: EndpointURL, container: BlobContainerName): BlobPathBuilder = { + def makeBlobPathBuilder(blobEndpoint: EndpointURL, + container: BlobContainerName): BlobPathBuilder = { val blobTokenGenerator = NativeBlobSasTokenGenerator(container, blobEndpoint, Some(subscriptionId)) val fsm = new BlobFileSystemManager(container, blobEndpoint, 10, blobTokenGenerator) - new BlobPathBuilder(store, endpoint)(fsm) + new BlobPathBuilder(container, blobEndpoint)(fsm) } - ignore should "resolve an absolute path string correctly to a path" in { - val builder = makeBlobPathBuilder(endpoint, store) - val rootString = s"${endpoint.value}/${store.value}/cromwell-execution" + it should "read md5 from small files <5g" in { + val builder = makeBlobPathBuilder(endpoint, container) + val evalPath = "/testRead.txt" + val testString = endpoint.value + "/" + container + evalPath + val blobPath1: BlobPath = (builder build testString).get + blobPath1.md5HexString.get should equal(Option("31ae06882d06a20e01ba1ac961ce576c")) + } + + it should "read md5 from large files >5g" in { + val builder = makeBlobPathBuilder(endpoint, container) + val evalPath = "/Rocky-9.2-aarch64-dvd.iso" + val testString = endpoint.value + "/" + container + evalPath + val blobPath1: BlobPath = (builder build testString).get + blobPath1.md5HexString.toOption.get should equal(Some("13cb09331d2d12c0f476f81c672a4319")) + } + + it should "choose the root/metadata md5 over the native md5 for files that have both" in { + val builder = makeBlobPathBuilder(endpoint, container) + val evalPath = "/redundant_md5_test.txt" + val testString = endpoint.value + "/" + container + evalPath + val blobPath1: BlobPath = (builder build testString).get + blobPath1.md5HexString.toOption.get should equal(Some("021c7cc715ec82292bb9b925f9ca44d3")) + } + + it should "gracefully return `None` when neither hash is found" in { + val builder = makeBlobPathBuilder(endpoint, container) + val evalPath = "/no_md5_test.txt" + val testString = endpoint.value + "/" + container + evalPath + val blobPath1: 
BlobPath = (builder build testString).get + blobPath1.md5HexString.get should equal(None) + } + + it should "resolve an absolute path string correctly to a path" in { + val builder = makeBlobPathBuilder(endpoint, container) + val rootString = s"${endpoint.value}/${container.value}/cromwell-execution" val blobRoot: BlobPath = builder build rootString getOrElse fail() - blobRoot.toAbsolutePath.pathAsString should equal ("https://coaexternalstorage.blob.core.windows.net/inputs/cromwell-execution") - val otherFile = blobRoot.resolve("https://coaexternalstorage.blob.core.windows.net/inputs/cromwell-execution/test/inputFile.txt") - otherFile.toAbsolutePath.pathAsString should equal ("https://coaexternalstorage.blob.core.windows.net/inputs/cromwell-execution/test/inputFile.txt") + blobRoot.toAbsolutePath.pathAsString should equal ("https://centaurtesting.blob.core.windows.net/test-blob/cromwell-execution") + val otherFile = blobRoot.resolve("https://centaurtesting.blob.core.windows.net/test-blob/cromwell-execution/test/inputFile.txt") + otherFile.toAbsolutePath.pathAsString should equal ("https://centaurtesting.blob.core.windows.net/test-blob/cromwell-execution/test/inputFile.txt") } - ignore should "build a blob path from a test string and read a file" in { - val builder = makeBlobPathBuilder(endpoint, store) + it should "build a blob path from a test string and read a file" in { + val builder = makeBlobPathBuilder(endpoint, container) val endpointHost = BlobPathBuilder.parseURI(endpoint.value).map(_.getHost).getOrElse(fail("Could not parse URI")) val evalPath = "/test/inputFile.txt" - val testString = endpoint.value + "/" + store + evalPath + val testString = endpoint.value + "/" + container + evalPath val blobPath: BlobPath = builder build testString getOrElse fail() - blobPath.container should equal(store) + blobPath.container should equal(container) blobPath.endpoint should equal(endpoint) blobPath.pathAsString should equal(testString) - blobPath.pathWithoutScheme should equal(endpointHost + "/" + store + evalPath) + blobPath.pathWithoutScheme should equal(endpointHost + "/" + container + evalPath) val is = blobPath.newInputStream() val fileText = (is.readAllBytes.map(_.toChar)).mkString fileText should include ("This is my test file!!!! Did it work?") } - ignore should "build duplicate blob paths in the same filesystem" in { - val builder = makeBlobPathBuilder(endpoint, store) + it should "build duplicate blob paths in the same filesystem" in { + val builder = makeBlobPathBuilder(endpoint, container) val evalPath = "/test/inputFile.txt" - val testString = endpoint.value + "/" + store + evalPath + val testString = endpoint.value + "/" + container + evalPath val blobPath1: BlobPath = builder build testString getOrElse fail() blobPath1.nioPath.getFileSystem.close() val blobPath2: BlobPath = builder build testString getOrElse fail() @@ -138,20 +172,20 @@ class BlobPathBuilderSpec extends AnyFlatSpec with Matchers with MockSugar { fileText should include ("This is my test file!!!! 
Did it work?") } - ignore should "resolve a path without duplicating container name" in { - val builder = makeBlobPathBuilder(endpoint, store) - val rootString = s"${endpoint.value}/${store.value}/cromwell-execution" + it should "resolve a path without duplicating container name" in { + val builder = makeBlobPathBuilder(endpoint, container) + val rootString = s"${endpoint.value}/${container.value}/cromwell-execution" val blobRoot: BlobPath = builder build rootString getOrElse fail() - blobRoot.toAbsolutePath.pathAsString should equal ("https://coaexternalstorage.blob.core.windows.net/inputs/cromwell-execution") + blobRoot.toAbsolutePath.pathAsString should equal ("https://centaurtesting.blob.core.windows.net/test-blob/cromwell-execution") val otherFile = blobRoot.resolve("test/inputFile.txt") - otherFile.toAbsolutePath.pathAsString should equal ("https://coaexternalstorage.blob.core.windows.net/inputs/cromwell-execution/test/inputFile.txt") + otherFile.toAbsolutePath.pathAsString should equal ("https://centaurtesting.blob.core.windows.net/test-blob/cromwell-execution/test/inputFile.txt") } - ignore should "correctly remove a prefix from the blob path" in { - val builder = makeBlobPathBuilder(endpoint, store) - val rootString = s"${endpoint.value}/${store.value}/cromwell-execution/" - val execDirString = s"${endpoint.value}/${store.value}/cromwell-execution/abc123/myworkflow/task1/def4356/execution/" - val fileString = s"${endpoint.value}/${store.value}/cromwell-execution/abc123/myworkflow/task1/def4356/execution/stdout" + it should "correctly remove a prefix from the blob path" in { + val builder = makeBlobPathBuilder(endpoint, container) + val rootString = s"${endpoint.value}/${container.value}/cromwell-execution/" + val execDirString = s"${endpoint.value}/${container.value}/cromwell-execution/abc123/myworkflow/task1/def4356/execution/" + val fileString = s"${endpoint.value}/${container.value}/cromwell-execution/abc123/myworkflow/task1/def4356/execution/stdout" val blobRoot: BlobPath = builder build rootString getOrElse fail() val execDir: BlobPath = builder build execDirString getOrElse fail() val blobFile: BlobPath = builder build fileString getOrElse fail() @@ -160,10 +194,10 @@ class BlobPathBuilderSpec extends AnyFlatSpec with Matchers with MockSugar { blobFile.pathStringWithoutPrefix(blobFile) should equal ("") } - ignore should "not change a path if it doesn't start with a prefix" in { - val builder = makeBlobPathBuilder(endpoint, store) - val otherRootString = s"${endpoint.value}/${store.value}/foobar/" - val fileString = s"${endpoint.value}/${store.value}/cromwell-execution/abc123/myworkflow/task1/def4356/execution/stdout" + it should "not change a path if it doesn't start with a prefix" in { + val builder = makeBlobPathBuilder(endpoint, container) + val otherRootString = s"${endpoint.value}/${container.value}/foobar/" + val fileString = s"${endpoint.value}/${container.value}/cromwell-execution/abc123/myworkflow/task1/def4356/execution/stdout" val otherBlobRoot: BlobPath = builder build otherRootString getOrElse fail() val blobFile: BlobPath = builder build fileString getOrElse fail() blobFile.pathStringWithoutPrefix(otherBlobRoot) should equal ("/cromwell-execution/abc123/myworkflow/task1/def4356/execution/stdout") From c12aaba2e082f1dadfdb52c6e78a36764a12808a Mon Sep 17 00:00:00 2001 From: Tom Wiseman Date: Fri, 18 Aug 2023 11:21:16 -0400 Subject: [PATCH 08/25] WX-1156 Fix internal_path_prefix (#7208) Co-authored-by: Janet Gainer-Dewar --- 
.../cromwell/filesystems/blob/BlobPathBuilder.scala |  8 ++++++++
 .../scala/cromwell/backend/impl/tes/TesJobPaths.scala | 11 ++++++++---
 .../scala/cromwell/backend/impl/tes/TesTask.scala     |  2 +-
 .../scala/cromwell/backend/impl/tes/TesTaskSpec.scala |  2 +-
 4 files changed, 18 insertions(+), 5 deletions(-)

diff --git a/filesystems/blob/src/main/scala/cromwell/filesystems/blob/BlobPathBuilder.scala b/filesystems/blob/src/main/scala/cromwell/filesystems/blob/BlobPathBuilder.scala
index aa6445fa779..35c518c0a43 100644
--- a/filesystems/blob/src/main/scala/cromwell/filesystems/blob/BlobPathBuilder.scala
+++ b/filesystems/blob/src/main/scala/cromwell/filesystems/blob/BlobPathBuilder.scala
@@ -174,5 +174,13 @@ case class BlobPath private[blob](pathString: String, endpoint: EndpointURL, con
     else pathString
   }

+  /**
+    * Returns the path relative to the container root.
+    * For example, https://{storageAccountName}.blob.core.windows.net/{containerid}/path/to/my/file
+    * will be returned as /path/to/my/file.
+    * @return Path string relative to the container root.
+    */
+  def pathWithoutContainer : String = pathString
+
   override def getSymlinkSafePath(options: LinkOption*): Path = toAbsolutePath
 }
diff --git a/supportedBackends/tes/src/main/scala/cromwell/backend/impl/tes/TesJobPaths.scala b/supportedBackends/tes/src/main/scala/cromwell/backend/impl/tes/TesJobPaths.scala
index f1797d9bc58..a624a16328a 100644
--- a/supportedBackends/tes/src/main/scala/cromwell/backend/impl/tes/TesJobPaths.scala
+++ b/supportedBackends/tes/src/main/scala/cromwell/backend/impl/tes/TesJobPaths.scala
@@ -4,6 +4,7 @@ import com.typesafe.config.Config
 import cromwell.backend.io.{JobPaths, WorkflowPaths}
 import cromwell.backend.{BackendJobDescriptorKey, BackendWorkflowDescriptor}
 import cromwell.core.path._
+import cromwell.filesystems.blob.BlobPath

 object TesJobPaths {
   def apply(jobKey: BackendJobDescriptorKey,
@@ -30,12 +31,16 @@ case class TesJobPaths private[tes] (override val workflowPaths: TesWorkflowPath

   /*
    * tesTaskRoot: This is the root directory that TES will use for files related to this task.
-   * We provide it to TES as a k/v pair where the key is "internal_path_prefix" (specified in TesWorkflowOptionKeys.scala)
-   * and the value is a blob path.
+   * TES expects a path relative to the root of the storage container.
+   * We provide it to TES as a k/v pair where the key is "internal_path_prefix" and the value is the relative path string.
    * This is not a standard TES feature, but rather related to the Azure TES implementation that Terra uses.
    * While passing it outside of Terra won't do any harm, we could consider making this optional and/or configurable.
    */
-  val tesTaskRoot : Path = callExecutionRoot.resolve("tes_task")
+  private val taskFullPath = callRoot./("tes_task")
+  val tesTaskRoot : String = taskFullPath match {
+    case blob: BlobPath => blob.pathWithoutContainer
+    case anyOtherPath: Path => anyOtherPath.pathAsString
+  }

   // Given an output path, return a path localized to the storage file system
   def storageOutput(path: String): String = {
diff --git a/supportedBackends/tes/src/main/scala/cromwell/backend/impl/tes/TesTask.scala b/supportedBackends/tes/src/main/scala/cromwell/backend/impl/tes/TesTask.scala
index 66f0f508d77..a345e87ebf7 100644
--- a/supportedBackends/tes/src/main/scala/cromwell/backend/impl/tes/TesTask.scala
+++ b/supportedBackends/tes/src/main/scala/cromwell/backend/impl/tes/TesTask.scala
@@ -242,7 +242,7 @@ final case class TesTask(jobDescriptor: BackendJobDescriptor,
   val resources: Resources = TesTask.makeResources(
     runtimeAttributes,
     preferedWorkflowExecutionIdentity,
-    Option(tesPaths.tesTaskRoot.pathAsString)
+    Option(tesPaths.tesTaskRoot)
   )
 }
diff --git a/supportedBackends/tes/src/test/scala/cromwell/backend/impl/tes/TesTaskSpec.scala b/supportedBackends/tes/src/test/scala/cromwell/backend/impl/tes/TesTaskSpec.scala
index 5bfa916086d..25a8f55f682 100644
--- a/supportedBackends/tes/src/test/scala/cromwell/backend/impl/tes/TesTaskSpec.scala
+++ b/supportedBackends/tes/src/test/scala/cromwell/backend/impl/tes/TesTaskSpec.scala
@@ -101,7 +101,7 @@ class TesTaskSpec
       TestConfig.emptyConfig)

     val expectedKey = "internal_path_prefix"
-    val expectedValue = Option(tesPaths.tesTaskRoot.pathAsString)
+    val expectedValue = Option(tesPaths.tesTaskRoot)

     // Assert path correctly ends up in the resources
     val wei = Option("abc123")

From 3affdc3541bfe92633b770b1302b056cd2b46f1a Mon Sep 17 00:00:00 2001
From: Janet Gainer-Dewar
Date: Mon, 21 Aug 2023 13:05:09 -0400
Subject: [PATCH 09/25] WX-1256 Temporarily turn off engine hashing for blob files (#7209)

Co-authored-by: Adam Nichols
---
 engine/src/main/scala/cromwell/engine/io/nio/NioFlow.scala | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/engine/src/main/scala/cromwell/engine/io/nio/NioFlow.scala b/engine/src/main/scala/cromwell/engine/io/nio/NioFlow.scala
index b6ce3ee7cc2..69e5551b8a9 100644
--- a/engine/src/main/scala/cromwell/engine/io/nio/NioFlow.scala
+++ b/engine/src/main/scala/cromwell/engine/io/nio/NioFlow.scala
@@ -159,7 +159,9 @@ class NioFlow(parallelism: Int,

     val fileContentIo = command.file match {
       case _: DrsPath => readFileAndChecksum
-      case _: BlobPath => readFileAndChecksum
+      // Temporarily disable since our hashing algorithm doesn't match the stored hash
+      // https://broadworkbench.atlassian.net/browse/WX-1257
+      case _: BlobPath => readFile // readFileAndChecksum
       case _ => readFile
     }
     fileContentIo.map(_.replaceAll("\\r\\n", "\\\n"))

From bdc1ab3b6becb4b496e44ea7faa1d69db64a43aa Mon Sep 17 00:00:00 2001
From: Christian Freitas
Date: Tue, 22 Aug 2023 10:07:15 -0400
Subject: [PATCH 10/25] WX-1173 Reopen filesystem for blob storage outside workspace (#7178)

Co-authored-by: Janet Gainer-Dewar
Co-authored-by: Tom Wiseman
Co-authored-by: Adam Nichols
---
 .../blob/nio/AzureDirectoryStream.java        |  14 +-
 .../storage/blob/nio/AzureFileSystem.java     | 100 ++++++----
 .../com/azure/storage/blob/nio/AzurePath.java |   2 +-
 build.sbt                                     |   5 +-
 .../blob/BlobFileSystemConfig.scala           |  17 +-
 .../blob/BlobFileSystemManager.scala          | 166 ++++++++--------
 .../filesystems/blob/BlobPathBuilder.scala    |  34 ++--
 .../blob/BlobPathBuilderFactory.scala         |  21 +-
.../WorkspaceManagerApiClientProvider.scala | 46 ++++- .../org.mockito.plugins.MockMaker | 1 + .../blob/AzureFileSystemSpec.scala | 25 +++ .../blob/BlobFileSystemConfigSpec.scala | 25 +-- .../blob/BlobPathBuilderFactorySpec.scala | 185 +++++++++++------- .../blob/BlobPathBuilderSpec.scala | 61 ++---- 14 files changed, 399 insertions(+), 303 deletions(-) create mode 100644 filesystems/blob/src/test/resources/mockito-extensions/org.mockito.plugins.MockMaker create mode 100644 filesystems/blob/src/test/scala/cromwell/filesystems/blob/AzureFileSystemSpec.scala diff --git a/azure-blob-nio/src/main/java/com/azure/storage/blob/nio/AzureDirectoryStream.java b/azure-blob-nio/src/main/java/com/azure/storage/blob/nio/AzureDirectoryStream.java index 917f712ddfc..817121e958e 100644 --- a/azure-blob-nio/src/main/java/com/azure/storage/blob/nio/AzureDirectoryStream.java +++ b/azure-blob-nio/src/main/java/com/azure/storage/blob/nio/AzureDirectoryStream.java @@ -3,12 +3,6 @@ package com.azure.storage.blob.nio; -import com.azure.core.util.logging.ClientLogger; -import com.azure.storage.blob.BlobContainerClient; -import com.azure.storage.blob.models.BlobItem; -import com.azure.storage.blob.models.BlobListDetails; -import com.azure.storage.blob.models.ListBlobsOptions; - import java.io.IOException; import java.nio.file.DirectoryIteratorException; import java.nio.file.DirectoryStream; @@ -18,6 +12,12 @@ import java.util.NoSuchElementException; import java.util.Set; +import com.azure.core.util.logging.ClientLogger; +import com.azure.storage.blob.BlobContainerClient; +import com.azure.storage.blob.models.BlobItem; +import com.azure.storage.blob.models.BlobListDetails; +import com.azure.storage.blob.models.ListBlobsOptions; + /** * A type for iterating over the contents of a directory. 
* @@ -88,7 +88,7 @@ private static class AzureDirectoryIterator implements Iterator { if (path.isRoot()) { String containerName = path.toString().substring(0, path.toString().length() - 1); AzureFileSystem afs = ((AzureFileSystem) path.getFileSystem()); - containerClient = ((AzureFileStore) afs.getFileStore(containerName)).getContainerClient(); + containerClient = ((AzureFileStore) afs.getFileStore()).getContainerClient(); } else { AzureResource azureResource = new AzureResource(path); listOptions.setPrefix(azureResource.getBlobClient().getBlobName() + AzureFileSystem.PATH_SEPARATOR); diff --git a/azure-blob-nio/src/main/java/com/azure/storage/blob/nio/AzureFileSystem.java b/azure-blob-nio/src/main/java/com/azure/storage/blob/nio/AzureFileSystem.java index 6f981b1b45e..381ed0289d7 100644 --- a/azure-blob-nio/src/main/java/com/azure/storage/blob/nio/AzureFileSystem.java +++ b/azure-blob-nio/src/main/java/com/azure/storage/blob/nio/AzureFileSystem.java @@ -3,19 +3,6 @@ package com.azure.storage.blob.nio; -import com.azure.core.credential.AzureSasCredential; -import com.azure.core.http.HttpClient; -import com.azure.core.http.policy.HttpLogDetailLevel; -import com.azure.core.http.policy.HttpPipelinePolicy; -import com.azure.core.util.CoreUtils; -import com.azure.core.util.logging.ClientLogger; -import com.azure.storage.blob.BlobServiceClient; -import com.azure.storage.blob.BlobServiceClientBuilder; -import com.azure.storage.blob.implementation.util.BlobUserAgentModificationPolicy; -import com.azure.storage.common.StorageSharedKeyCredential; -import com.azure.storage.common.policy.RequestRetryOptions; -import com.azure.storage.common.policy.RetryPolicyType; - import java.io.IOException; import java.nio.file.FileStore; import java.nio.file.FileSystem; @@ -27,14 +14,31 @@ import java.nio.file.attribute.FileAttributeView; import java.nio.file.attribute.UserPrincipalLookupService; import java.nio.file.spi.FileSystemProvider; +import java.time.Duration; +import java.time.Instant; +import java.util.Arrays; import java.util.Collections; import java.util.HashMap; import java.util.HashSet; +import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.Optional; import java.util.Set; import java.util.regex.PatternSyntaxException; -import java.util.stream.Collectors; + +import com.azure.core.credential.AzureSasCredential; +import com.azure.core.http.HttpClient; +import com.azure.core.http.policy.HttpLogDetailLevel; +import com.azure.core.http.policy.HttpPipelinePolicy; +import com.azure.core.util.CoreUtils; +import com.azure.core.util.logging.ClientLogger; +import com.azure.storage.blob.BlobServiceClient; +import com.azure.storage.blob.BlobServiceClientBuilder; +import com.azure.storage.blob.implementation.util.BlobUserAgentModificationPolicy; +import com.azure.storage.common.StorageSharedKeyCredential; +import com.azure.storage.common.policy.RequestRetryOptions; +import com.azure.storage.common.policy.RetryPolicyType; /** * Implement's Java's {@link FileSystem} interface for Azure Blob Storage. 
@@ -67,6 +71,11 @@ public final class AzureFileSystem extends FileSystem {
      */
     public static final String AZURE_STORAGE_SAS_TOKEN_CREDENTIAL = "AzureStorageSasTokenCredential";

+    /**
+     * Expected type: com.azure.core.credential.AzureSasCredential
+     */
+    public static final String AZURE_STORAGE_PUBLIC_ACCESS_CREDENTIAL = "AzureStoragePublicAccessCredential";
+
     /**
      * Expected type: com.azure.core.http.policy.HttpLogDetailLevel
      */
@@ -159,9 +168,9 @@ public final class AzureFileSystem extends FileSystem {
     private final Long putBlobThreshold;
     private final Integer maxConcurrencyPerRequest;
     private final Integer downloadResumeRetries;
-    private final Map fileStores;
     private FileStore defaultFileStore;
     private boolean closed;
+    private Instant expiry;

     AzureFileSystem(AzureFileSystemProvider parentFileSystemProvider, String endpoint, Map config)
         throws IOException {
@@ -181,7 +190,7 @@ public final class AzureFileSystem extends FileSystem {
             this.downloadResumeRetries = (Integer) config.get(AZURE_STORAGE_DOWNLOAD_RESUME_RETRIES);

             // Initialize and ensure access to FileStores.
-            this.fileStores = this.initializeFileStores(config);
+            this.defaultFileStore = this.initializeFileStore(config);
         } catch (RuntimeException e) {
             throw LoggingUtility.logError(LOGGER, new IllegalArgumentException("There was an error parsing the "
                 + "configurations map. Please ensure all fields are set to a legal value of the correct type.", e));
@@ -221,7 +230,7 @@ public FileSystemProvider provider() {
     @Override
     public void close() throws IOException {
         this.closed = true;
-        this.parentFileSystemProvider.closeFileSystem(this.getFileSystemUrl());
+        this.parentFileSystemProvider.closeFileSystem(this.getFileSystemUrl() + "/" + defaultFileStore.name());
     }

     /**
@@ -282,9 +291,7 @@ public Iterable getRootDirectories() {
         If the file system was set to use all containers in the account, the account will be re-queried and the
         list may grow or shrink if containers were added or deleted.
         */
-        return fileStores.keySet().stream()
-            .map(name -> this.getPath(name + AzurePath.ROOT_DIR_SUFFIX))
-            .collect(Collectors.toList());
+        return Arrays.asList(this.getPath(defaultFileStore.name() + AzurePath.ROOT_DIR_SUFFIX));
     }

     /**
@@ -304,7 +311,7 @@ public Iterable getFileStores() {
         If the file system was set to use all containers in the account, the account will be re-queried and the
         list may grow or shrink if containers were added or deleted.
         */
-        return this.fileStores.values();
+        return Arrays.asList(defaultFileStore);
     }

     /**
@@ -397,6 +404,12 @@ private BlobServiceClient buildBlobServiceClient(String endpoint, Map
             builder.credential((StorageSharedKeyCredential) config.get(AZURE_STORAGE_SHARED_KEY_CREDENTIAL));
         } else if (config.containsKey(AZURE_STORAGE_SAS_TOKEN_CREDENTIAL)) {
             builder.credential((AzureSasCredential) config.get(AZURE_STORAGE_SAS_TOKEN_CREDENTIAL));
+            this.setExpiryFromSAS((AzureSasCredential) config.get(AZURE_STORAGE_SAS_TOKEN_CREDENTIAL));
+        } else if (config.containsKey(AZURE_STORAGE_PUBLIC_ACCESS_CREDENTIAL)) {
+            // The Blob Service Client Builder requires at least one kind of authentication to make requests.
+            // For public files, however, this is unnecessary. This key-value pair is to denote the case
+            // explicitly when we supply a placeholder SAS credential to bypass this requirement.
+            builder.credential((AzureSasCredential) config.get(AZURE_STORAGE_PUBLIC_ACCESS_CREDENTIAL));
         } else {
             throw LoggingUtility.logError(LOGGER, new IllegalArgumentException(String.format("No credentials were "
                 + "provided.
Please specify one of the following when constructing an AzureFileSystem: %s, %s.", @@ -430,23 +443,17 @@ private BlobServiceClient buildBlobServiceClient(String endpoint, Map return builder.buildClient(); } - private Map initializeFileStores(Map config) throws IOException { - String fileStoreNames = (String) config.get(AZURE_STORAGE_FILE_STORES); - if (CoreUtils.isNullOrEmpty(fileStoreNames)) { + private FileStore initializeFileStore(Map config) throws IOException { + String fileStoreName = (String) config.get(AZURE_STORAGE_FILE_STORES); + if (CoreUtils.isNullOrEmpty(fileStoreName)) { throw LoggingUtility.logError(LOGGER, new IllegalArgumentException("The list of FileStores cannot be " + "null.")); } Boolean skipConnectionCheck = (Boolean) config.get(AZURE_STORAGE_SKIP_INITIAL_CONTAINER_CHECK); Map fileStores = new HashMap<>(); - for (String fileStoreName : fileStoreNames.split(",")) { - FileStore fs = new AzureFileStore(this, fileStoreName, skipConnectionCheck); - if (this.defaultFileStore == null) { - this.defaultFileStore = fs; - } - fileStores.put(fileStoreName, fs); - } - return fileStores; + this.defaultFileStore = new AzureFileStore(this, fileStoreName, skipConnectionCheck); + return this.defaultFileStore; } @Override @@ -470,12 +477,11 @@ Path getDefaultDirectory() { return this.getPath(this.defaultFileStore.name() + AzurePath.ROOT_DIR_SUFFIX); } - FileStore getFileStore(String name) throws IOException { - FileStore store = this.fileStores.get(name); - if (store == null) { - throw LoggingUtility.logError(LOGGER, new IOException("Invalid file store: " + name)); + FileStore getFileStore() throws IOException { + if (this.defaultFileStore == null) { + throw LoggingUtility.logError(LOGGER, new IOException("FileStore not initialized")); } - return store; + return defaultFileStore; } Long getBlockSize() { @@ -489,4 +495,24 @@ Long getPutBlobThreshold() { Integer getMaxConcurrencyPerRequest() { return this.maxConcurrencyPerRequest; } + + public Optional getExpiry() { + return Optional.ofNullable(expiry); + } + + private void setExpiryFromSAS(AzureSasCredential token) { + List strings = Arrays.asList(token.getSignature().split("&")); + Optional expiryString = strings.stream() + .filter(s -> s.startsWith("se")) + .findFirst() + .map(s -> s.replaceFirst("se=","")) + .map(s -> s.replace("%3A", ":")); + this.expiry = expiryString.map(es -> Instant.parse(es)).orElse(null); + } + + public boolean isExpired(Duration buffer) { + return Optional.ofNullable(this.expiry) + .map(e -> Instant.now().plus(buffer).isAfter(e)) + .orElse(true); + } } diff --git a/azure-blob-nio/src/main/java/com/azure/storage/blob/nio/AzurePath.java b/azure-blob-nio/src/main/java/com/azure/storage/blob/nio/AzurePath.java index 9742af1f696..917895ba39e 100644 --- a/azure-blob-nio/src/main/java/com/azure/storage/blob/nio/AzurePath.java +++ b/azure-blob-nio/src/main/java/com/azure/storage/blob/nio/AzurePath.java @@ -735,7 +735,7 @@ public BlobClient toBlobClient() throws IOException { String fileStoreName = this.rootToFileStore(root.toString()); BlobContainerClient containerClient = - ((AzureFileStore) this.parentFileSystem.getFileStore(fileStoreName)).getContainerClient(); + ((AzureFileStore) this.parentFileSystem.getFileStore()).getContainerClient(); String blobName = this.withoutRoot(); if (blobName.isEmpty()) { diff --git a/build.sbt b/build.sbt index 814e966b364..2c9a8068992 100644 --- a/build.sbt +++ b/build.sbt @@ -103,10 +103,11 @@ lazy val azureBlobNio = (project in file("azure-blob-nio")) lazy val azureBlobFileSystem 
= (project in file("filesystems/blob")) .withLibrarySettings("cromwell-azure-blobFileSystem", blobFileSystemDependencies) .dependsOn(core) - .dependsOn(core % "test->test") - .dependsOn(common % "test->test") .dependsOn(cloudSupport) .dependsOn(azureBlobNio) + .dependsOn(core % "test->test") + .dependsOn(common % "test->test") + .dependsOn(azureBlobNio % "test->test") lazy val awsS3FileSystem = (project in file("filesystems/s3")) .withLibrarySettings("cromwell-aws-s3filesystem", s3FileSystemDependencies) diff --git a/filesystems/blob/src/main/scala/cromwell/filesystems/blob/BlobFileSystemConfig.scala b/filesystems/blob/src/main/scala/cromwell/filesystems/blob/BlobFileSystemConfig.scala index f68bf7f5176..c5467c78ffe 100644 --- a/filesystems/blob/src/main/scala/cromwell/filesystems/blob/BlobFileSystemConfig.scala +++ b/filesystems/blob/src/main/scala/cromwell/filesystems/blob/BlobFileSystemConfig.scala @@ -11,13 +11,9 @@ import java.util.UUID // WSM config is needed for accessing WSM-managed blob containers created in Terra workspaces. // If the identity executing Cromwell has native access to the blob container, this can be ignored. final case class WorkspaceManagerConfig(url: WorkspaceManagerURL, - workspaceId: WorkspaceId, - containerResourceId: ContainerResourceId, overrideWsmAuthToken: Option[String]) // dev-only -final case class BlobFileSystemConfig(endpointURL: EndpointURL, - blobContainerName: BlobContainerName, - subscriptionId: Option[SubscriptionId], +final case class BlobFileSystemConfig(subscriptionId: Option[SubscriptionId], expiryBufferMinutes: Long, workspaceManagerConfig: Option[WorkspaceManagerConfig]) @@ -26,8 +22,6 @@ object BlobFileSystemConfig { final val defaultExpiryBufferMinutes = 10L def apply(config: Config): BlobFileSystemConfig = { - val endpointURL = parseString(config, "endpoint").map(EndpointURL) - val blobContainer = parseString(config, "container").map(BlobContainerName) val subscriptionId = parseUUIDOpt(config, "subscription").map(_.map(SubscriptionId)) val expiryBufferMinutes = parseLongOpt(config, "expiry-buffer-minutes") @@ -37,17 +31,15 @@ object BlobFileSystemConfig { if (config.hasPath("workspace-manager")) { val wsmConf = config.getConfig("workspace-manager") val wsmURL = parseString(wsmConf, "url").map(WorkspaceManagerURL) - val workspaceId = parseUUID(wsmConf, "workspace-id").map(WorkspaceId) - val containerResourceId = parseUUID(wsmConf, "container-resource-id").map(ContainerResourceId) val overrideWsmAuthToken = parseStringOpt(wsmConf, "b2cToken") - (wsmURL, workspaceId, containerResourceId, overrideWsmAuthToken) + (wsmURL, overrideWsmAuthToken) .mapN(WorkspaceManagerConfig) .map(Option(_)) } else None.validNel - (endpointURL, blobContainer, subscriptionId, expiryBufferMinutes, wsmConfig) + (subscriptionId, expiryBufferMinutes, wsmConfig) .mapN(BlobFileSystemConfig.apply) .unsafe("Couldn't parse blob filesystem config") } @@ -58,9 +50,6 @@ object BlobFileSystemConfig { private def parseStringOpt(config: Config, path: String) = validate[Option[String]] { config.as[Option[String]](path) } - private def parseUUID(config: Config, path: String) = - validate[UUID] { UUID.fromString(config.as[String](path)) } - private def parseUUIDOpt(config: Config, path: String) = validate[Option[UUID]] { config.as[Option[String]](path).map(UUID.fromString) } diff --git a/filesystems/blob/src/main/scala/cromwell/filesystems/blob/BlobFileSystemManager.scala b/filesystems/blob/src/main/scala/cromwell/filesystems/blob/BlobFileSystemManager.scala index 
ac8f01d2cc7..6b6088c7689 100644 --- a/filesystems/blob/src/main/scala/cromwell/filesystems/blob/BlobFileSystemManager.scala +++ b/filesystems/blob/src/main/scala/cromwell/filesystems/blob/BlobFileSystemManager.scala @@ -1,5 +1,6 @@ package cromwell.filesystems.blob +import bio.terra.workspace.client.ApiException import com.azure.core.credential.AzureSasCredential import com.azure.storage.blob.nio.{AzureFileSystem, AzureFileSystemProvider} import com.azure.storage.blob.sas.{BlobContainerSasPermission, BlobServiceSasSignatureValues} @@ -9,18 +10,19 @@ import common.validation.Validation._ import cromwell.cloudsupport.azure.{AzureCredentials, AzureUtils} import java.net.URI +import java.nio.file._ import java.nio.file.spi.FileSystemProvider -import java.nio.file.{FileSystem, FileSystemNotFoundException} import java.time.temporal.ChronoUnit import java.time.{Duration, Instant, OffsetDateTime} +import java.util.UUID import scala.jdk.CollectionConverters._ import scala.util.{Failure, Success, Try} // We encapsulate this functionality here so that we can easily mock it out, to allow for testing without // actually connecting to Blob storage. -case class FileSystemAPI(private val provider: FileSystemProvider = new AzureFileSystemProvider()) { - def getFileSystem(uri: URI): Try[FileSystem] = Try(provider.getFileSystem(uri)) - def newFileSystem(uri: URI, config: Map[String, Object]): FileSystem = provider.newFileSystem(uri, config.asJava) +case class AzureFileSystemAPI(private val provider: FileSystemProvider = new AzureFileSystemProvider()) { + def getFileSystem(uri: URI): Try[AzureFileSystem] = Try(provider.getFileSystem(uri).asInstanceOf[AzureFileSystem]) + def newFileSystem(uri: URI, config: Map[String, Object]): Try[AzureFileSystem] = Try(provider.newFileSystem(uri, config.asJava).asInstanceOf[AzureFileSystem]) def closeFileSystem(uri: URI): Option[Unit] = getFileSystem(uri).toOption.map(_.close) } /** @@ -36,25 +38,25 @@ object BlobFileSystemManager { } yield instant def buildConfigMap(credential: AzureSasCredential, container: BlobContainerName): Map[String, Object] = { - Map((AzureFileSystem.AZURE_STORAGE_SAS_TOKEN_CREDENTIAL, credential), - (AzureFileSystem.AZURE_STORAGE_FILE_STORES, container.value), - (AzureFileSystem.AZURE_STORAGE_SKIP_INITIAL_CONTAINER_CHECK, java.lang.Boolean.TRUE)) + // Special handling is done here to provide a special key value pair if the placeholder token is provided + // This is due to the BlobClient requiring an auth token even for public blob paths. 
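+    // Illustration (assumed values): for a public container "inputs" the resulting map is
+    //   AZURE_STORAGE_PUBLIC_ACCESS_CREDENTIAL     -> PLACEHOLDER_TOKEN
+    //   AZURE_STORAGE_FILE_STORES                  -> "inputs"
+    //   AZURE_STORAGE_SKIP_INITIAL_CONTAINER_CHECK -> true
+    // whereas a genuine SAS credential is stored under AZURE_STORAGE_SAS_TOKEN_CREDENTIAL instead.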
+ val sasTuple = if (credential == PLACEHOLDER_TOKEN) (AzureFileSystem.AZURE_STORAGE_PUBLIC_ACCESS_CREDENTIAL, PLACEHOLDER_TOKEN) + else (AzureFileSystem.AZURE_STORAGE_SAS_TOKEN_CREDENTIAL, credential) + + Map(sasTuple, (AzureFileSystem.AZURE_STORAGE_FILE_STORES, container.value), + (AzureFileSystem.AZURE_STORAGE_SKIP_INITIAL_CONTAINER_CHECK, java.lang.Boolean.TRUE)) } - def hasTokenExpired(tokenExpiry: Instant, buffer: Duration): Boolean = Instant.now.plus(buffer).isAfter(tokenExpiry) - def uri(endpoint: EndpointURL) = new URI("azb://?endpoint=" + endpoint) + def combinedEnpointContainerUri(endpoint: EndpointURL, container: BlobContainerName) = new URI("azb://?endpoint=" + endpoint + "/" + container.value) + + val PLACEHOLDER_TOKEN = new AzureSasCredential("this-is-a-public-sas") } -class BlobFileSystemManager(val container: BlobContainerName, - val endpoint: EndpointURL, - val expiryBufferMinutes: Long, +class BlobFileSystemManager(val expiryBufferMinutes: Long, val blobTokenGenerator: BlobSasTokenGenerator, - val fileSystemAPI: FileSystemAPI = FileSystemAPI(), - private val initialExpiration: Option[Instant] = None) extends LazyLogging { + val fileSystemAPI: AzureFileSystemAPI = AzureFileSystemAPI()) extends LazyLogging { def this(config: BlobFileSystemConfig) = { this( - config.blobContainerName, - config.endpointURL, config.expiryBufferMinutes, BlobSasTokenGenerator.createBlobTokenGeneratorFromConfig(config) ) @@ -63,39 +65,46 @@ class BlobFileSystemManager(val container: BlobContainerName, def this(rawConfig: Config) = this(BlobFileSystemConfig(rawConfig)) val buffer: Duration = Duration.of(expiryBufferMinutes, ChronoUnit.MINUTES) - private var expiry: Option[Instant] = initialExpiration - def getExpiry: Option[Instant] = expiry - def uri: URI = BlobFileSystemManager.uri(endpoint) - def isTokenExpired: Boolean = expiry.exists(BlobFileSystemManager.hasTokenExpired(_, buffer)) - def shouldReopenFilesystem: Boolean = isTokenExpired || expiry.isEmpty - def retrieveFilesystem(): Try[FileSystem] = { + def retrieveFilesystem(endpoint: EndpointURL, container: BlobContainerName): Try[FileSystem] = { + val uri = BlobFileSystemManager.combinedEnpointContainerUri(endpoint, container) synchronized { - shouldReopenFilesystem match { - case false => fileSystemAPI.getFileSystem(uri).recoverWith { - // If no filesystem already exists, this will create a new connection, with the provided configs - case _: FileSystemNotFoundException => - logger.info(s"Creating new blob filesystem for URI $uri") - blobTokenGenerator.generateBlobSasToken.flatMap(generateFilesystem(uri, container, _)) + fileSystemAPI.getFileSystem(uri).filter(!_.isExpired(buffer)).recoverWith { + // If no filesystem already exists, this will create a new connection, with the provided configs + case _: FileSystemNotFoundException => { + logger.info(s"Creating new blob filesystem for URI $uri") + generateFilesystem(uri, container, endpoint) } - // If the token has expired, OR there is no token record, try to close the FS and regenerate - case true => + case _ : NoSuchElementException => { + // When the filesystem expires, the above filter results in a + // NoSuchElementException. 
If expired, close the filesystem
+        // and reopen it with a fresh token.
         logger.info(s"Closing & regenerating token for existing blob filesystem at URI $uri")
         fileSystemAPI.closeFileSystem(uri)
-        blobTokenGenerator.generateBlobSasToken.flatMap(generateFilesystem(uri, container, _))
+        generateFilesystem(uri, container, endpoint)
+      }
     }
   }
 }

-  private def generateFilesystem(uri: URI, container: BlobContainerName, token: AzureSasCredential): Try[FileSystem] = {
-    expiry = BlobFileSystemManager.parseTokenExpiry(token)
-    if (expiry.isEmpty) return Failure(new Exception("Could not reopen filesystem, no expiration found"))
-    Try(fileSystemAPI.newFileSystem(uri, BlobFileSystemManager.buildConfigMap(token, container)))
+  /**
+    * Create a new filesystem pointing to a particular container and storage account,
+    * generating a SAS token from WSM as needed
+    *
+    * @param uri a URI formatted to include the scheme, storage account endpoint and container
+    * @param container the container to open as a filesystem
+    * @param endpoint the endpoint containing the storage account for the container to open
+    * @return a try with either the successfully created filesystem, or a failure containing the exception
+    */
+  private def generateFilesystem(uri: URI, container: BlobContainerName, endpoint: EndpointURL): Try[AzureFileSystem] = {
+    blobTokenGenerator.generateBlobSasToken(endpoint, container)
+      .flatMap((token: AzureSasCredential) => {
+        fileSystemAPI.newFileSystem(uri, BlobFileSystemManager.buildConfigMap(token, container))
+      })
   }
-
 }

-sealed trait BlobSasTokenGenerator { def generateBlobSasToken: Try[AzureSasCredential] }
+sealed trait BlobSasTokenGenerator { def generateBlobSasToken(endpoint: EndpointURL, container: BlobContainerName): Try[AzureSasCredential] }

 object BlobSasTokenGenerator {

   /**
@@ -122,35 +131,23 @@ object BlobSasTokenGenerator {
       // WSM-mediated SAS token generator
       // parameterizing client instead of URL to make injecting mock client possible
-      BlobSasTokenGenerator.createBlobTokenGenerator(
-        config.blobContainerName,
-        config.endpointURL,
-        wsmConfig.workspaceId,
-        wsmConfig.containerResourceId,
-        wsmClient,
-        wsmConfig.overrideWsmAuthToken
-      )
+      BlobSasTokenGenerator.createBlobTokenGenerator(wsmClient, wsmConfig.overrideWsmAuthToken)
     }.getOrElse(
       // Native SAS token generator
-      BlobSasTokenGenerator.createBlobTokenGenerator(config.blobContainerName, config.endpointURL, config.subscriptionId)
+      BlobSasTokenGenerator.createBlobTokenGenerator(config.subscriptionId)
     )

   /**
    * Native SAS token generator, uses the DefaultAzureCredentialBuilder in the local environment
    * to produce a SAS token.
    *
-   * @param container The BlobContainerName of the blob container to be accessed by the generated SAS token
-   * @param endpoint The EndpointURL containing the storage account of the blob container to be accessed by
-   * this SAS token
    * @param subscription Optional subscription parameter to use for local authorization.
* If one is not provided the default subscription is used * @return A NativeBlobTokenGenerator, able to produce a valid SAS token for accessing the provided blob * container and endpoint locally */ - def createBlobTokenGenerator(container: BlobContainerName, - endpoint: EndpointURL, - subscription: Option[SubscriptionId]): BlobSasTokenGenerator = { - NativeBlobSasTokenGenerator(container, endpoint, subscription) + def createBlobTokenGenerator(subscription: Option[SubscriptionId]): BlobSasTokenGenerator = { + NativeBlobSasTokenGenerator(subscription) } /** @@ -158,11 +155,6 @@ object BlobSasTokenGenerator { * to request a SAS token from the WSM to access the given blob container. If an overrideWsmAuthToken * is provided this is used instead. * - * @param container The BlobContainerName of the blob container to be accessed by the generated SAS token - * @param endpoint The EndpointURL containing the storage account of the blob container to be accessed by - * this SAS token - * @param workspaceId The WorkspaceId of the account to authenticate against - * @param containerResourceId The ContainterResourceId of the blob container as WSM knows it * @param workspaceManagerClient The client for making requests against WSM * @param overrideWsmAuthToken An optional WsmAuthToken used for authenticating against the WSM for a valid * SAS token to access the given container and endpoint. This is a dev only option that is only intended @@ -170,54 +162,56 @@ object BlobSasTokenGenerator { * @return A WSMBlobTokenGenerator, able to produce a valid SAS token for accessing the provided blob * container and endpoint that is managed by WSM */ - def createBlobTokenGenerator(container: BlobContainerName, - endpoint: EndpointURL, - workspaceId: WorkspaceId, - containerResourceId: ContainerResourceId, - workspaceManagerClient: WorkspaceManagerApiClientProvider, + def createBlobTokenGenerator(workspaceManagerClient: WorkspaceManagerApiClientProvider, overrideWsmAuthToken: Option[String]): BlobSasTokenGenerator = { - WSMBlobSasTokenGenerator(container, endpoint, workspaceId, containerResourceId, workspaceManagerClient, overrideWsmAuthToken) + WSMBlobSasTokenGenerator(workspaceManagerClient, overrideWsmAuthToken) } } -case class WSMBlobSasTokenGenerator(container: BlobContainerName, - endpoint: EndpointURL, - workspaceId: WorkspaceId, - containerResourceId: ContainerResourceId, - wsmClientProvider: WorkspaceManagerApiClientProvider, +case class WSMBlobSasTokenGenerator(wsmClientProvider: WorkspaceManagerApiClientProvider, overrideWsmAuthToken: Option[String]) extends BlobSasTokenGenerator { /** * Generate a BlobSasToken by using the available authorization information * If an overrideWsmAuthToken is provided, use this in the wsmClient request * Else try to use the environment azure identity to request the SAS token + * @param endpoint The EndpointURL of the blob container to be accessed by the generated SAS token + * @param container The BlobContainerName of the blob container to be accessed by the generated SAS token * * @return an AzureSasCredential for accessing a blob container */ - def generateBlobSasToken: Try[AzureSasCredential] = { + def generateBlobSasToken(endpoint: EndpointURL, container: BlobContainerName): Try[AzureSasCredential] = { val wsmAuthToken: Try[String] = overrideWsmAuthToken match { case Some(t) => Success(t) case None => AzureCredentials.getAccessToken(None).toTry } + container.workspaceId match { + // If this is a Terra workspace, request a token from WSM + case Success(workspaceId) => { + 
(for { + wsmAuth <- wsmAuthToken + wsmAzureResourceClient = wsmClientProvider.getControlledAzureResourceApi(wsmAuth) + resourceId <- getContainerResourceId(workspaceId, container, wsmAuth) + sasToken <- wsmAzureResourceClient.createAzureStorageContainerSasToken(workspaceId, resourceId) + } yield sasToken).recoverWith { + // If the storage account was still not found in WSM, this may be a public filesystem + case exception: ApiException if exception.getCode == 404 => Try(BlobFileSystemManager.PLACEHOLDER_TOKEN) + } + } + // Otherwise assume that the container is public and use a placeholder + // SAS token to bypass the BlobClient authentication requirement + case Failure(_) => Try(BlobFileSystemManager.PLACEHOLDER_TOKEN) + } + } - for { - wsmAuth <- wsmAuthToken - wsmClient = wsmClientProvider.getControlledAzureResourceApi(wsmAuth) - sasToken <- Try( // Java library throws - wsmClient.createAzureStorageContainerSasToken( - workspaceId.value, - containerResourceId.value, - null, - null, - null, - null - ).getToken) - } yield new AzureSasCredential(sasToken) + def getContainerResourceId(workspaceId: UUID, container: BlobContainerName, wsmAuth : String): Try[UUID] = { + val wsmResourceClient = wsmClientProvider.getResourceApi(wsmAuth) + wsmResourceClient.findContainerResourceId(workspaceId, container) } } -case class NativeBlobSasTokenGenerator(container: BlobContainerName, endpoint: EndpointURL, subscription: Option[SubscriptionId] = None) extends BlobSasTokenGenerator { +case class NativeBlobSasTokenGenerator(subscription: Option[SubscriptionId] = None) extends BlobSasTokenGenerator { private val bcsp = new BlobContainerSasPermission() .setReadPermission(true) .setCreatePermission(true) @@ -227,10 +221,12 @@ case class NativeBlobSasTokenGenerator(container: BlobContainerName, endpoint: E /** * Generate a BlobSasToken by using the local environment azure identity * This will use a default subscription if one is not provided. 
+ * @param endpoint The EndpointURL of the blob container to be accessed by the generated SAS token + * @param container The BlobContainerName of the blob container to be accessed by the generated SAS token * * @return an AzureSasCredential for accessing a blob container */ - def generateBlobSasToken: Try[AzureSasCredential] = for { + def generateBlobSasToken(endpoint: EndpointURL, container: BlobContainerName): Try[AzureSasCredential] = for { bcc <- AzureUtils.buildContainerClientFromLocalEnvironment(container.toString, endpoint.toString, subscription.map(_.toString)) bsssv = new BlobServiceSasSignatureValues(OffsetDateTime.now.plusDays(1), bcsp) asc = new AzureSasCredential(bcc.generateSas(bsssv)) diff --git a/filesystems/blob/src/main/scala/cromwell/filesystems/blob/BlobPathBuilder.scala b/filesystems/blob/src/main/scala/cromwell/filesystems/blob/BlobPathBuilder.scala index 35c518c0a43..3aa26eb3c11 100644 --- a/filesystems/blob/src/main/scala/cromwell/filesystems/blob/BlobPathBuilder.scala +++ b/filesystems/blob/src/main/scala/cromwell/filesystems/blob/BlobPathBuilder.scala @@ -12,12 +12,13 @@ import scala.language.postfixOps import scala.util.{Failure, Success, Try} object BlobPathBuilder { - + private val blobHostnameSuffix = ".blob.core.windows.net" sealed trait BlobPathValidation - case class ValidBlobPath(path: String) extends BlobPathValidation + case class ValidBlobPath(path: String, container: BlobContainerName, endpoint: EndpointURL) extends BlobPathValidation case class UnparsableBlobPath(errorMessage: Throwable) extends BlobPathValidation - def invalidBlobPathMessage(container: BlobContainerName, endpoint: EndpointURL) = s"Malformed Blob URL for this builder. Expecting a URL for a container $container and endpoint $endpoint" + def invalidBlobHostMessage(endpoint: EndpointURL) = s"Malformed Blob URL for this builder: The endpoint $endpoint doesn't contain the expected host string '{SA}.blob.core.windows.net/'" + def invalidBlobContainerMessage(endpoint: EndpointURL) = s"Malformed Blob URL for this builder: Could not parse container" def parseURI(string: String): Try[URI] = Try(URI.create(UrlEscapers.urlFragmentEscaper().escape(string))) def parseStorageAccount(uri: URI): Try[StorageAccountName] = uri.getHost.split("\\.").find(_.nonEmpty).map(StorageAccountName(_)) .map(Success(_)).getOrElse(Failure(new Exception("Could not parse storage account"))) @@ -40,28 +41,31 @@ object BlobPathBuilder { * * If the configured container and storage account do not match, the string is considered unparsable */ - def validateBlobPath(string: String, container: BlobContainerName, endpoint: EndpointURL): BlobPathValidation = { + def validateBlobPath(string: String): BlobPathValidation = { val blobValidation = for { testUri <- parseURI(string) - endpointUri <- parseURI(endpoint.value) + testEndpoint = EndpointURL(testUri.getScheme + "://" + testUri.getHost()) testStorageAccount <- parseStorageAccount(testUri) - endpointStorageAccount <- parseStorageAccount(endpointUri) - hasContainer = testUri.getPath.split("/").find(_.nonEmpty).contains(container.value) - hasEndpoint = testStorageAccount.equals(endpointStorageAccount) - blobPathValidation = (hasContainer && hasEndpoint) match { - case true => ValidBlobPath(testUri.getPath.replaceFirst("/" + container, "")) - case false => UnparsableBlobPath(new MalformedURLException(invalidBlobPathMessage(container, endpoint))) + testContainer = testUri.getPath.split("/").find(_.nonEmpty) + isBlobHost = testUri.getHost().contains(blobHostnameSuffix) && 
testUri.getScheme().contains("https")
+      blobPathValidation = (isBlobHost, testContainer) match {
+        case (true, Some(container)) => ValidBlobPath(
+          testUri.getPath.replaceFirst("/" + container, ""),
+          BlobContainerName(container),
+          testEndpoint)
+        case (false, _) => UnparsableBlobPath(new MalformedURLException(invalidBlobHostMessage(testEndpoint)))
+        case (true, None) => UnparsableBlobPath(new MalformedURLException(invalidBlobContainerMessage(testEndpoint)))
       }
     } yield blobPathValidation
     blobValidation recover { case t => UnparsableBlobPath(t) } get
   }
 }

-class BlobPathBuilder(container: BlobContainerName, endpoint: EndpointURL)(private val fsm: BlobFileSystemManager) extends PathBuilder {
+class BlobPathBuilder()(private val fsm: BlobFileSystemManager) extends PathBuilder {

   def build(string: String): Try[BlobPath] = {
-    validateBlobPath(string, container, endpoint) match {
-      case ValidBlobPath(path) => Try(BlobPath(path, endpoint, container)(fsm))
+    validateBlobPath(string) match {
+      case ValidBlobPath(path, container, endpoint) => Try(BlobPath(path, endpoint, container)(fsm))
       case UnparsableBlobPath(errorMessage: Throwable) => Failure(errorMessage)
     }
   }
@@ -121,7 +125,7 @@ case class BlobPath private[blob](pathString: String, endpoint: EndpointURL, con
   override def pathWithoutScheme: String = parseURI(endpoint.value).map(u => List(u.getHost, container, pathString.stripPrefix("/")).mkString("/")).get

   private def findNioPath(path: String): NioPath = (for {
-    fileSystem <- fsm.retrieveFilesystem()
+    fileSystem <- fsm.retrieveFilesystem(endpoint, container)
     // The Azure NIO library uses `{container}:` to represent the root of the path
     nioPath = fileSystem.getPath(s"${container.value}:", path)
     // This is purposefully an unprotected get because the NIO API needs an unwrapped path object.
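The validator above now infers both the endpoint and the container from the URL
itself instead of checking them against a preconfigured pair. A quick usage sketch;
the account and container names are made up:

    import cromwell.filesystems.blob.BlobPathBuilder._

    validateBlobPath("https://acct.blob.core.windows.net/inputs/dir/file.txt") match {
      case ValidBlobPath(path, container, endpoint) =>
        // path == "/dir/file.txt", container == BlobContainerName("inputs"),
        // endpoint == EndpointURL("https://acct.blob.core.windows.net")
        println(s"$endpoint/$container$path")
      case UnparsableBlobPath(err) =>
        // Non-https URLs and hosts without ".blob.core.windows.net" land here
        println(err.getMessage)
    }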
diff --git a/filesystems/blob/src/main/scala/cromwell/filesystems/blob/BlobPathBuilderFactory.scala b/filesystems/blob/src/main/scala/cromwell/filesystems/blob/BlobPathBuilderFactory.scala index c263841dc8a..47245552dc2 100644 --- a/filesystems/blob/src/main/scala/cromwell/filesystems/blob/BlobPathBuilderFactory.scala +++ b/filesystems/blob/src/main/scala/cromwell/filesystems/blob/BlobPathBuilderFactory.scala @@ -8,11 +8,26 @@ import cromwell.core.path.PathBuilderFactory.PriorityBlob import java.util.UUID import scala.concurrent.{ExecutionContext, Future} +import scala.util.Try final case class SubscriptionId(value: UUID) {override def toString: String = value.toString} -final case class BlobContainerName(value: String) {override def toString: String = value} +final case class BlobContainerName(value: String) { + override def toString: String = value + lazy val workspaceId: Try[UUID] = { + Try(UUID.fromString(value.replaceFirst("sc-",""))) + } +} final case class StorageAccountName(value: String) {override def toString: String = value} -final case class EndpointURL(value: String) {override def toString: String = value} +final case class EndpointURL(value: String) { + override def toString: String = value + lazy val storageAccountName : Try[StorageAccountName] = { + val sa = for { + host <- value.split("//").findLast(_.nonEmpty) + storageAccountName <- host.split("\\.").find(_.nonEmpty) + } yield StorageAccountName(storageAccountName) + sa.toRight(new Exception(s"Storage account name could not be parsed from $value")).toTry + } +} final case class WorkspaceId(value: UUID) {override def toString: String = value.toString} final case class ContainerResourceId(value: UUID) {override def toString: String = value.toString} final case class WorkspaceManagerURL(value: String) {override def toString: String = value} @@ -21,7 +36,7 @@ final case class BlobPathBuilderFactory(globalConfig: Config, instanceConfig: Co override def withOptions(options: WorkflowOptions)(implicit as: ActorSystem, ec: ExecutionContext): Future[BlobPathBuilder] = { Future { - new BlobPathBuilder(fsm.container, fsm.endpoint)(fsm) + new BlobPathBuilder()(fsm) } } diff --git a/filesystems/blob/src/main/scala/cromwell/filesystems/blob/WorkspaceManagerApiClientProvider.scala b/filesystems/blob/src/main/scala/cromwell/filesystems/blob/WorkspaceManagerApiClientProvider.scala index a9f52d92a91..276738c98b6 100644 --- a/filesystems/blob/src/main/scala/cromwell/filesystems/blob/WorkspaceManagerApiClientProvider.scala +++ b/filesystems/blob/src/main/scala/cromwell/filesystems/blob/WorkspaceManagerApiClientProvider.scala @@ -1,7 +1,13 @@ package cromwell.filesystems.blob -import bio.terra.workspace.api.ControlledAzureResourceApi +import bio.terra.workspace.api._ import bio.terra.workspace.client.ApiClient +import bio.terra.workspace.model.{ResourceType, StewardshipType} +import com.azure.core.credential.AzureSasCredential + +import java.util.UUID +import scala.jdk.CollectionConverters._ +import scala.util.Try /** * Represents a way to get a client for interacting with workspace manager controlled resources. 
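The `workspaceId` helper added to BlobContainerName above encodes Terra's naming
convention for workspace storage containers, `sc-{workspaceId}`; any other name is
treated as a non-Terra (possibly public) container. Illustratively, with a made-up
UUID:

    BlobContainerName("sc-2f65b6d8-9a3c-4c3a-8f0e-1b2c3d4e5f60").workspaceId
    // Success(2f65b6d8-9a3c-4c3a-8f0e-1b2c3d4e5f60)
    BlobContainerName("my-public-container").workspaceId
    // Failure(java.lang.IllegalArgumentException): routed to the placeholder-token path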
@@ -12,7 +18,8 @@ import bio.terra.workspace.client.ApiClient * For testing, create an anonymous subclass as in `org.broadinstitute.dsde.rawls.dataaccess.workspacemanager.HttpWorkspaceManagerDAOSpec` */ trait WorkspaceManagerApiClientProvider { - def getControlledAzureResourceApi(token: String): ControlledAzureResourceApi + def getControlledAzureResourceApi(token: String): WsmControlledAzureResourceApi + def getResourceApi(token: String): WsmResourceApi } class HttpWorkspaceManagerClientProvider(baseWorkspaceManagerUrl: WorkspaceManagerURL) extends WorkspaceManagerApiClientProvider { @@ -22,9 +29,40 @@ class HttpWorkspaceManagerClientProvider(baseWorkspaceManagerUrl: WorkspaceManag client } - def getControlledAzureResourceApi(token: String): ControlledAzureResourceApi = { + def getResourceApi(token: String): WsmResourceApi = { + val apiClient = getApiClient + apiClient.setAccessToken(token) + WsmResourceApi(new ResourceApi(apiClient)) + } + + def getControlledAzureResourceApi(token: String): WsmControlledAzureResourceApi = { val apiClient = getApiClient apiClient.setAccessToken(token) - new ControlledAzureResourceApi(apiClient) + WsmControlledAzureResourceApi(new ControlledAzureResourceApi(apiClient)) + } +} + +case class WsmResourceApi(resourcesApi : ResourceApi) { + def findContainerResourceId(workspaceId : UUID, container: BlobContainerName): Try[UUID] = { + for { + workspaceResources <- Try(resourcesApi.enumerateResources(workspaceId, 0, 10, ResourceType.AZURE_STORAGE_CONTAINER, StewardshipType.CONTROLLED).getResources()) + workspaceStorageContainerOption = workspaceResources.asScala.find(r => r.getMetadata().getName() == container.value) + workspaceStorageContainer <- workspaceStorageContainerOption.toRight(new Exception("No storage container found for this workspace")).toTry + resourceId = workspaceStorageContainer.getMetadata().getResourceId() + } yield resourceId + } +} +case class WsmControlledAzureResourceApi(controlledAzureResourceApi : ControlledAzureResourceApi) { + def createAzureStorageContainerSasToken(workspaceId: UUID, resourceId: UUID): Try[AzureSasCredential] = { + for { + sas <- Try(controlledAzureResourceApi.createAzureStorageContainerSasToken( + workspaceId, + resourceId, + null, + null, + null, + null + ).getToken) + } yield new AzureSasCredential(sas) } } diff --git a/filesystems/blob/src/test/resources/mockito-extensions/org.mockito.plugins.MockMaker b/filesystems/blob/src/test/resources/mockito-extensions/org.mockito.plugins.MockMaker new file mode 100644 index 00000000000..1f0955d450f --- /dev/null +++ b/filesystems/blob/src/test/resources/mockito-extensions/org.mockito.plugins.MockMaker @@ -0,0 +1 @@ +mock-maker-inline diff --git a/filesystems/blob/src/test/scala/cromwell/filesystems/blob/AzureFileSystemSpec.scala b/filesystems/blob/src/test/scala/cromwell/filesystems/blob/AzureFileSystemSpec.scala new file mode 100644 index 00000000000..e0463bab740 --- /dev/null +++ b/filesystems/blob/src/test/scala/cromwell/filesystems/blob/AzureFileSystemSpec.scala @@ -0,0 +1,25 @@ +package cromwell.filesystems.blob + +import com.azure.storage.blob.nio.{AzureFileSystem, AzureFileSystemProvider} +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.matchers.should.Matchers + +import java.time.Instant +import scala.compat.java8.OptionConverters._ +import scala.jdk.CollectionConverters._ + +class AzureFileSystemSpec extends AnyFlatSpec with Matchers { + val now = Instant.now() + val container = BlobContainerName("testConainer") + val exampleSas = 
BlobPathBuilderFactorySpec.buildExampleSasToken(now) + val exampleConfig = BlobFileSystemManager.buildConfigMap(exampleSas, container) + val exampleStorageEndpoint = BlobPathBuilderSpec.buildEndpoint("testStorageAccount") + val exampleCombinedEndpoint = BlobFileSystemManager.combinedEnpointContainerUri(exampleStorageEndpoint, container) + + it should "parse an expiration from a sas token" in { + val provider = new AzureFileSystemProvider() + val filesystem : AzureFileSystem = provider.newFileSystem(exampleCombinedEndpoint, exampleConfig.asJava).asInstanceOf[AzureFileSystem] + filesystem.getExpiry.asScala shouldBe Some(now) + filesystem.getFileStores.asScala.map(_.name()).exists(_ == container.value) shouldBe true + } +} diff --git a/filesystems/blob/src/test/scala/cromwell/filesystems/blob/BlobFileSystemConfigSpec.scala b/filesystems/blob/src/test/scala/cromwell/filesystems/blob/BlobFileSystemConfigSpec.scala index 607ad5606f7..68804113763 100644 --- a/filesystems/blob/src/test/scala/cromwell/filesystems/blob/BlobFileSystemConfigSpec.scala +++ b/filesystems/blob/src/test/scala/cromwell/filesystems/blob/BlobFileSystemConfigSpec.scala @@ -5,14 +5,8 @@ import common.exception.AggregatedMessageException import org.scalatest.flatspec.AnyFlatSpec import org.scalatest.matchers.should.Matchers -import java.util.UUID - class BlobFileSystemConfigSpec extends AnyFlatSpec with Matchers { - private val endpoint = BlobPathBuilderSpec.buildEndpoint("storageAccount") - private val container = BlobContainerName("storageContainer") - private val workspaceId = WorkspaceId(UUID.fromString("B0BAFE77-0000-0000-0000-000000000000")) - private val containerResourceId = ContainerResourceId(UUID.fromString("F00B4911-0000-0000-0000-000000000000")) private val workspaceManagerURL = WorkspaceManagerURL("https://wsm.example.com") private val b2cToken = "b0gus-t0ken" @@ -20,12 +14,8 @@ class BlobFileSystemConfigSpec extends AnyFlatSpec with Matchers { val config = BlobFileSystemConfig( ConfigFactory.parseString( s""" - |container = "$container" - |endpoint = "$endpoint" """.stripMargin) ) - config.blobContainerName should equal(container) - config.endpointURL should equal(endpoint) config.expiryBufferMinutes should equal(BlobFileSystemConfig.defaultExpiryBufferMinutes) } @@ -33,25 +23,17 @@ class BlobFileSystemConfigSpec extends AnyFlatSpec with Matchers { val config = BlobFileSystemConfig( ConfigFactory.parseString( s""" - |container = "$container" - |endpoint = "$endpoint" |expiry-buffer-minutes = "20" |workspace-manager { | url = "$workspaceManagerURL" - | workspace-id = "$workspaceId" - | container-resource-id = "$containerResourceId" | b2cToken = "$b2cToken" |} | """.stripMargin) ) - config.blobContainerName should equal(container) - config.endpointURL should equal(endpoint) config.expiryBufferMinutes should equal(20L) config.workspaceManagerConfig.isDefined shouldBe true config.workspaceManagerConfig.get.url shouldBe workspaceManagerURL - config.workspaceManagerConfig.get.workspaceId shouldBe workspaceId - config.workspaceManagerConfig.get.containerResourceId shouldBe containerResourceId config.workspaceManagerConfig.get.overrideWsmAuthToken.contains(b2cToken) shouldBe true } @@ -59,17 +41,14 @@ class BlobFileSystemConfigSpec extends AnyFlatSpec with Matchers { val rawConfig = ConfigFactory.parseString( s""" - |container = "$container" - |endpoint = "$endpoint" |expiry-buffer-minutes = "10" |workspace-manager { - | url = "$workspaceManagerURL" - | container-resource-id = "$containerResourceId" + | b2cToken = 
"$b2cToken" |} | """.stripMargin) val error = intercept[AggregatedMessageException](BlobFileSystemConfig(rawConfig)) - error.getMessage should include("No configuration setting found for key 'workspace-id'") + error.getMessage should include("No configuration setting found for key 'url'") } } diff --git a/filesystems/blob/src/test/scala/cromwell/filesystems/blob/BlobPathBuilderFactorySpec.scala b/filesystems/blob/src/test/scala/cromwell/filesystems/blob/BlobPathBuilderFactorySpec.scala index 881cd3669a1..c4ee102c58b 100644 --- a/filesystems/blob/src/test/scala/cromwell/filesystems/blob/BlobPathBuilderFactorySpec.scala +++ b/filesystems/blob/src/test/scala/cromwell/filesystems/blob/BlobPathBuilderFactorySpec.scala @@ -1,16 +1,18 @@ package cromwell.filesystems.blob import com.azure.core.credential.AzureSasCredential +import com.azure.storage.blob.nio.AzureFileSystem import common.mock.MockSugar import org.mockito.Mockito._ import org.scalatest.flatspec.AnyFlatSpec import org.scalatest.matchers.should.Matchers -import java.nio.file.{FileSystem, FileSystemNotFoundException} +import java.nio.file.FileSystemNotFoundException import java.time.format.DateTimeFormatter import java.time.temporal.ChronoUnit import java.time.{Duration, Instant, ZoneId} -import scala.util.{Failure, Try} +import java.util.UUID +import scala.util.{Failure, Success, Try} object BlobPathBuilderFactorySpec { @@ -37,23 +39,12 @@ class BlobPathBuilderFactorySpec extends AnyFlatSpec with Matchers with MockSuga expiry should contain(expiryTime) } - it should "verify an unexpired token will be processed as unexpired" in { - val expiryTime = generateTokenExpiration(11L) - val expired = BlobFileSystemManager.hasTokenExpired(expiryTime, Duration.ofMinutes(10L)) - expired shouldBe false - } - - it should "test an expired token will be processed as expired" in { - val expiryTime = generateTokenExpiration(9L) - val expired = BlobFileSystemManager.hasTokenExpired(expiryTime, Duration.ofMinutes(10L)) - expired shouldBe true - } - it should "test that a filesystem gets closed correctly" in { val endpoint = BlobPathBuilderSpec.buildEndpoint("storageAccount") - val azureUri = BlobFileSystemManager.uri(endpoint) - val fileSystems = mock[FileSystemAPI] - val fileSystem = mock[FileSystem] + val container = BlobContainerName("test") + val azureUri = BlobFileSystemManager.combinedEnpointContainerUri(endpoint, container) + val fileSystems = mock[AzureFileSystemAPI] + val fileSystem = mock[AzureFileSystem] when(fileSystems.getFileSystem(azureUri)).thenReturn(Try(fileSystem)) when(fileSystems.closeFileSystem(azureUri)).thenCallRealMethod() @@ -61,106 +52,156 @@ class BlobPathBuilderFactorySpec extends AnyFlatSpec with Matchers with MockSuga verify(fileSystem, times(1)).close() } - it should "test retrieveFileSystem with expired filesystem" in { + it should "test retrieveFileSystem with expired Terra filesystem" in { val endpoint = BlobPathBuilderSpec.buildEndpoint("storageAccount") - val expiredToken = generateTokenExpiration(9L) + //val expiredToken = generateTokenExpiration(9L) val refreshedToken = generateTokenExpiration(69L) val sasToken = BlobPathBuilderFactorySpec.buildExampleSasToken(refreshedToken) - val container = BlobContainerName("storageContainer") + val container = BlobContainerName("sc-" + UUID.randomUUID().toString()) val configMap = BlobFileSystemManager.buildConfigMap(sasToken, container) - val azureUri = BlobFileSystemManager.uri(endpoint) - - val fileSystems = mock[FileSystemAPI] + val azureUri = 
BlobFileSystemManager.combinedEnpointContainerUri(endpoint, container) + + //Mocking this final class requires the plugin Mock Maker Inline plugin, configured here + //at filesystems/blob/src/test/resources/mockito-extensions/org.mockito.plugins.MockMaker + val azureFileSystem = mock[AzureFileSystem] + when(azureFileSystem.isExpired(Duration.ofMinutes(10L))).thenReturn(true) + val fileSystems = mock[AzureFileSystemAPI] + when(fileSystems.getFileSystem(azureUri)).thenReturn(Success(azureFileSystem)) val blobTokenGenerator = mock[BlobSasTokenGenerator] - when(blobTokenGenerator.generateBlobSasToken).thenReturn(Try(sasToken)) + when(blobTokenGenerator.generateBlobSasToken(endpoint, container)).thenReturn(Try(sasToken)) - val fsm = new BlobFileSystemManager(container, endpoint, 10L, blobTokenGenerator, fileSystems, Some(expiredToken)) - fsm.getExpiry should contain(expiredToken) - fsm.isTokenExpired shouldBe true - fsm.retrieveFilesystem() + val fsm = new BlobFileSystemManager(10L, blobTokenGenerator, fileSystems) + fsm.retrieveFilesystem(endpoint, container) - fsm.getExpiry should contain(refreshedToken) - fsm.isTokenExpired shouldBe false - verify(fileSystems, never()).getFileSystem(azureUri) + verify(fileSystems, times(1)).getFileSystem(azureUri) verify(fileSystems, times(1)).newFileSystem(azureUri, configMap) verify(fileSystems, times(1)).closeFileSystem(azureUri) } - it should "test retrieveFileSystem with an unexpired fileSystem" in { + it should "test retrieveFileSystem with an unexpired Terra fileSystem" in { val endpoint = BlobPathBuilderSpec.buildEndpoint("storageAccount") - val initialToken = generateTokenExpiration(11L) + //val initialToken = generateTokenExpiration(11L) val refreshedToken = generateTokenExpiration(71L) val sasToken = BlobPathBuilderFactorySpec.buildExampleSasToken(refreshedToken) - val container = BlobContainerName("storageContainer") + val container = BlobContainerName("sc-" + UUID.randomUUID().toString()) val configMap = BlobFileSystemManager.buildConfigMap(sasToken, container) - val azureUri = BlobFileSystemManager.uri(endpoint) - // Need a fake filesystem to supply the getFileSystem simulated try - val dummyFileSystem = mock[FileSystem] + val azureUri = BlobFileSystemManager.combinedEnpointContainerUri(endpoint,container) - val fileSystems = mock[FileSystemAPI] - when(fileSystems.getFileSystem(azureUri)).thenReturn(Try(dummyFileSystem)) + //Mocking this final class requires the plugin Mock Maker Inline plugin, configured here + //at filesystems/blob/src/test/resources/mockito-extensions/org.mockito.plugins.MockMaker + val azureFileSystem = mock[AzureFileSystem] + when(azureFileSystem.isExpired(Duration.ofMinutes(10L))).thenReturn(false) + val fileSystems = mock[AzureFileSystemAPI] + when(fileSystems.getFileSystem(azureUri)).thenReturn(Try(azureFileSystem)) val blobTokenGenerator = mock[BlobSasTokenGenerator] - when(blobTokenGenerator.generateBlobSasToken).thenReturn(Try(sasToken)) + when(blobTokenGenerator.generateBlobSasToken(endpoint, container)).thenReturn(Try(sasToken)) - val fsm = new BlobFileSystemManager(container, endpoint, 10L, blobTokenGenerator, fileSystems, Some(initialToken)) - fsm.getExpiry should contain(initialToken) - fsm.isTokenExpired shouldBe false - fsm.retrieveFilesystem() + val fsm = new BlobFileSystemManager(10L, blobTokenGenerator, fileSystems) + fsm.retrieveFilesystem(endpoint, container) - fsm.getExpiry should contain(initialToken) - fsm.isTokenExpired shouldBe false verify(fileSystems, times(1)).getFileSystem(azureUri) 
verify(fileSystems, never()).newFileSystem(azureUri, configMap) verify(fileSystems, never()).closeFileSystem(azureUri) } - it should "test retrieveFileSystem with an uninitialized filesystem" in { + it should "test retrieveFileSystem with an uninitialized Terra filesystem" in { val endpoint = BlobPathBuilderSpec.buildEndpoint("storageAccount") val refreshedToken = generateTokenExpiration(71L) val sasToken = BlobPathBuilderFactorySpec.buildExampleSasToken(refreshedToken) - val container = BlobContainerName("storageContainer") + val container = BlobContainerName("sc-" + UUID.randomUUID().toString()) val configMap = BlobFileSystemManager.buildConfigMap(sasToken, container) - val azureUri = BlobFileSystemManager.uri(endpoint) + val azureUri = BlobFileSystemManager.combinedEnpointContainerUri(endpoint, container) - val fileSystems = mock[FileSystemAPI] + //Mocking this final class requires the plugin Mock Maker Inline plugin, configured here + //at filesystems/blob/src/test/resources/mockito-extensions/org.mockito.plugins.MockMaker + val azureFileSystem = mock[AzureFileSystem] + when(azureFileSystem.isExpired(Duration.ofMinutes(10L))).thenReturn(false) + val fileSystems = mock[AzureFileSystemAPI] when(fileSystems.getFileSystem(azureUri)).thenReturn(Failure(new FileSystemNotFoundException)) + when(fileSystems.newFileSystem(azureUri, configMap)).thenReturn(Try(azureFileSystem)) val blobTokenGenerator = mock[BlobSasTokenGenerator] - when(blobTokenGenerator.generateBlobSasToken).thenReturn(Try(sasToken)) + when(blobTokenGenerator.generateBlobSasToken(endpoint, container)).thenReturn(Try(sasToken)) - val fsm = new BlobFileSystemManager(container, endpoint, 10L, blobTokenGenerator, fileSystems, Some(refreshedToken)) - fsm.getExpiry.isDefined shouldBe true - fsm.isTokenExpired shouldBe false - fsm.retrieveFilesystem() + val fsm = new BlobFileSystemManager(0L, blobTokenGenerator, fileSystems) + fsm.retrieveFilesystem(endpoint, container) - fsm.getExpiry should contain(refreshedToken) - fsm.isTokenExpired shouldBe false verify(fileSystems, times(1)).getFileSystem(azureUri) verify(fileSystems, times(1)).newFileSystem(azureUri, configMap) verify(fileSystems, never()).closeFileSystem(azureUri) } - it should "test retrieveFileSystem with an unknown filesystem" in { + it should "test retrieveFileSystem with expired non-Terra filesystem" in { val endpoint = BlobPathBuilderSpec.buildEndpoint("storageAccount") - val refreshedToken = generateTokenExpiration(71L) - val sasToken = BlobPathBuilderFactorySpec.buildExampleSasToken(refreshedToken) - val container = BlobContainerName("storageContainer") + val sasToken = BlobFileSystemManager.PLACEHOLDER_TOKEN + val container = BlobContainerName("sc-" + UUID.randomUUID().toString()) val configMap = BlobFileSystemManager.buildConfigMap(sasToken, container) - val azureUri = BlobFileSystemManager.uri(endpoint) - - val fileSystems = mock[FileSystemAPI] + val azureUri = BlobFileSystemManager.combinedEnpointContainerUri(endpoint, container) + + //Mocking this final class requires the plugin Mock Maker Inline plugin, configured here + //at filesystems/blob/src/test/resources/mockito-extensions/org.mockito.plugins.MockMaker + val azureFileSystem = mock[AzureFileSystem] + when(azureFileSystem.isExpired(Duration.ofMinutes(10L))).thenReturn(true) + val fileSystems = mock[AzureFileSystemAPI] + when(fileSystems.getFileSystem(azureUri)).thenReturn(Success(azureFileSystem)) val blobTokenGenerator = mock[BlobSasTokenGenerator] - 
when(blobTokenGenerator.generateBlobSasToken).thenReturn(Try(sasToken)) + when(blobTokenGenerator.generateBlobSasToken(endpoint, container)).thenReturn(Try(sasToken)) - val fsm = new BlobFileSystemManager(container, endpoint, 10L, blobTokenGenerator, fileSystems) - fsm.getExpiry.isDefined shouldBe false - fsm.isTokenExpired shouldBe false - fsm.retrieveFilesystem() + val fsm = new BlobFileSystemManager(10L, blobTokenGenerator, fileSystems) + fsm.retrieveFilesystem(endpoint, container) - fsm.getExpiry should contain(refreshedToken) - fsm.isTokenExpired shouldBe false - verify(fileSystems, never()).getFileSystem(azureUri) + verify(fileSystems, times(1)).getFileSystem(azureUri) verify(fileSystems, times(1)).newFileSystem(azureUri, configMap) verify(fileSystems, times(1)).closeFileSystem(azureUri) } + + it should "test retrieveFileSystem with an unexpired non-Terra fileSystem" in { + val endpoint = BlobPathBuilderSpec.buildEndpoint("storageAccount") + val sasToken = BlobFileSystemManager.PLACEHOLDER_TOKEN + val container = BlobContainerName("sc-" + UUID.randomUUID().toString()) + val configMap = BlobFileSystemManager.buildConfigMap(sasToken, container) + val azureUri = BlobFileSystemManager.combinedEnpointContainerUri(endpoint,container) + + //Mocking this final class requires the plugin Mock Maker Inline plugin, configured here + //at filesystems/blob/src/test/resources/mockito-extensions/org.mockito.plugins.MockMaker + val azureFileSystem = mock[AzureFileSystem] + when(azureFileSystem.isExpired(Duration.ofMinutes(10L))).thenReturn(false) + val fileSystems = mock[AzureFileSystemAPI] + when(fileSystems.getFileSystem(azureUri)).thenReturn(Try(azureFileSystem)) + + val blobTokenGenerator = mock[BlobSasTokenGenerator] + when(blobTokenGenerator.generateBlobSasToken(endpoint, container)).thenReturn(Try(sasToken)) + + val fsm = new BlobFileSystemManager(10L, blobTokenGenerator, fileSystems) + fsm.retrieveFilesystem(endpoint, container) + + verify(fileSystems, times(1)).getFileSystem(azureUri) + verify(fileSystems, never()).newFileSystem(azureUri, configMap) + verify(fileSystems, never()).closeFileSystem(azureUri) + } + + it should "test retrieveFileSystem with an uninitialized non-Terra filesystem" in { + val endpoint = BlobPathBuilderSpec.buildEndpoint("storageAccount") + val sasToken = BlobFileSystemManager.PLACEHOLDER_TOKEN + val container = BlobContainerName("sc-" + UUID.randomUUID().toString()) + val configMap = BlobFileSystemManager.buildConfigMap(sasToken, container) + val azureUri = BlobFileSystemManager.combinedEnpointContainerUri(endpoint, container) + + //Mocking this final class requires the plugin Mock Maker Inline plugin, configured here + //at filesystems/blob/src/test/resources/mockito-extensions/org.mockito.plugins.MockMaker + val azureFileSystem = mock[AzureFileSystem] + when(azureFileSystem.isExpired(Duration.ofMinutes(10L))).thenReturn(false) + val fileSystems = mock[AzureFileSystemAPI] + when(fileSystems.getFileSystem(azureUri)).thenReturn(Failure(new FileSystemNotFoundException)) + when(fileSystems.newFileSystem(azureUri, configMap)).thenReturn(Try(azureFileSystem)) + val blobTokenGenerator = mock[BlobSasTokenGenerator] + when(blobTokenGenerator.generateBlobSasToken(endpoint, container)).thenReturn(Try(sasToken)) + + val fsm = new BlobFileSystemManager(0L, blobTokenGenerator, fileSystems) + fsm.retrieveFilesystem(endpoint, container) + + verify(fileSystems, times(1)).getFileSystem(azureUri) + verify(fileSystems, times(1)).newFileSystem(azureUri, configMap) + verify(fileSystems, 
never()).closeFileSystem(azureUri) + } } diff --git a/filesystems/blob/src/test/scala/cromwell/filesystems/blob/BlobPathBuilderSpec.scala b/filesystems/blob/src/test/scala/cromwell/filesystems/blob/BlobPathBuilderSpec.scala index eef6db8e942..a8ca7d58d6f 100644 --- a/filesystems/blob/src/test/scala/cromwell/filesystems/blob/BlobPathBuilderSpec.scala +++ b/filesystems/blob/src/test/scala/cromwell/filesystems/blob/BlobPathBuilderSpec.scala @@ -18,41 +18,23 @@ class BlobPathBuilderSpec extends AnyFlatSpec with Matchers with MockSugar { val container = BlobContainerName("container") val evalPath = "/path/to/file" val testString = endpoint.value + "/" + container + evalPath - BlobPathBuilder.validateBlobPath(testString, container, endpoint) match { - case BlobPathBuilder.ValidBlobPath(path) => path should equal(evalPath) + BlobPathBuilder.validateBlobPath(testString) match { + case BlobPathBuilder.ValidBlobPath(path, parsedContainer, parsedEndpoint) => { + path should equal(evalPath) + parsedContainer should equal(container) + parsedEndpoint should equal(endpoint) + } case BlobPathBuilder.UnparsableBlobPath(errorMessage) => fail(errorMessage) } } - it should "bad storage account fails causes URI to fail parse into a path" in { - val endpoint = BlobPathBuilderSpec.buildEndpoint("storageAccount") - val container = BlobContainerName("container") - val evalPath = "/path/to/file" - val testString = BlobPathBuilderSpec.buildEndpoint("badStorageAccount").value + container.value + evalPath - BlobPathBuilder.validateBlobPath(testString, container, endpoint) match { - case BlobPathBuilder.ValidBlobPath(path) => fail(s"Valid path: $path found when verifying mismatched storage account") - case BlobPathBuilder.UnparsableBlobPath(errorMessage) => errorMessage.getMessage should equal(BlobPathBuilder.invalidBlobPathMessage(container, endpoint)) - } - } - - it should "bad container fails causes URI to fail parse into a path" in { - val endpoint = BlobPathBuilderSpec.buildEndpoint("storageAccount") - val container = BlobContainerName("container") - val evalPath = "/path/to/file" - val testString = endpoint.value + "badContainer" + evalPath - BlobPathBuilder.validateBlobPath(testString, container, endpoint) match { - case BlobPathBuilder.ValidBlobPath(path) => fail(s"Valid path: $path found when verifying mismatched container") - case BlobPathBuilder.UnparsableBlobPath(errorMessage) => errorMessage.getMessage should equal(BlobPathBuilder.invalidBlobPathMessage(container, endpoint)) - } - } - it should "provide a readable error when getting an illegal nioPath" in { val endpoint = BlobPathBuilderSpec.buildEndpoint("storageAccount") val container = BlobContainerName("container") val evalPath = "/path/to/file" val exception = new Exception("Failed to do the thing") val fsm = mock[BlobFileSystemManager] - when(fsm.retrieveFilesystem()).thenReturn(Failure(exception)) + when(fsm.retrieveFilesystem(endpoint, container)).thenReturn(Failure(exception)) val path = BlobPath(evalPath, endpoint, container)(fsm) val testException = Try(path.nioPath).failed.toOption testException should contain(exception) @@ -95,15 +77,14 @@ class BlobPathBuilderSpec extends AnyFlatSpec with Matchers with MockSugar { private val endpoint: EndpointURL = BlobPathBuilderSpec.buildEndpoint("centaurtesting") private val container: BlobContainerName = BlobContainerName("test-blob") - def makeBlobPathBuilder(blobEndpoint: EndpointURL, - container: BlobContainerName): BlobPathBuilder = { - val blobTokenGenerator = NativeBlobSasTokenGenerator(container, 
blobEndpoint, Some(subscriptionId)) - val fsm = new BlobFileSystemManager(container, blobEndpoint, 10, blobTokenGenerator) - new BlobPathBuilder(container, blobEndpoint)(fsm) + def makeBlobPathBuilder(): BlobPathBuilder = { + val blobTokenGenerator = NativeBlobSasTokenGenerator(Some(subscriptionId)) + val fsm = new BlobFileSystemManager(10, blobTokenGenerator) + new BlobPathBuilder()(fsm) } it should "read md5 from small files <5g" in { - val builder = makeBlobPathBuilder(endpoint, container) + val builder = makeBlobPathBuilder() val evalPath = "/testRead.txt" val testString = endpoint.value + "/" + container + evalPath val blobPath1: BlobPath = (builder build testString).get @@ -111,7 +92,7 @@ class BlobPathBuilderSpec extends AnyFlatSpec with Matchers with MockSugar { } it should "read md5 from large files >5g" in { - val builder = makeBlobPathBuilder(endpoint, container) + val builder = makeBlobPathBuilder() val evalPath = "/Rocky-9.2-aarch64-dvd.iso" val testString = endpoint.value + "/" + container + evalPath val blobPath1: BlobPath = (builder build testString).get @@ -119,7 +100,7 @@ class BlobPathBuilderSpec extends AnyFlatSpec with Matchers with MockSugar { } it should "choose the root/metadata md5 over the native md5 for files that have both" in { - val builder = makeBlobPathBuilder(endpoint, container) + val builder = makeBlobPathBuilder() val evalPath = "/redundant_md5_test.txt" val testString = endpoint.value + "/" + container + evalPath val blobPath1: BlobPath = (builder build testString).get @@ -127,7 +108,7 @@ class BlobPathBuilderSpec extends AnyFlatSpec with Matchers with MockSugar { } it should "gracefully return `None` when neither hash is found" in { - val builder = makeBlobPathBuilder(endpoint, container) + val builder = makeBlobPathBuilder() val evalPath = "/no_md5_test.txt" val testString = endpoint.value + "/" + container + evalPath val blobPath1: BlobPath = (builder build testString).get @@ -135,7 +116,7 @@ class BlobPathBuilderSpec extends AnyFlatSpec with Matchers with MockSugar { } it should "resolve an absolute path string correctly to a path" in { - val builder = makeBlobPathBuilder(endpoint, container) + val builder = makeBlobPathBuilder() val rootString = s"${endpoint.value}/${container.value}/cromwell-execution" val blobRoot: BlobPath = builder build rootString getOrElse fail() blobRoot.toAbsolutePath.pathAsString should equal ("https://centaurtesting.blob.core.windows.net/test-blob/cromwell-execution") @@ -144,7 +125,7 @@ class BlobPathBuilderSpec extends AnyFlatSpec with Matchers with MockSugar { } it should "build a blob path from a test string and read a file" in { - val builder = makeBlobPathBuilder(endpoint, container) + val builder = makeBlobPathBuilder() val endpointHost = BlobPathBuilder.parseURI(endpoint.value).map(_.getHost).getOrElse(fail("Could not parse URI")) val evalPath = "/test/inputFile.txt" val testString = endpoint.value + "/" + container + evalPath @@ -160,7 +141,7 @@ class BlobPathBuilderSpec extends AnyFlatSpec with Matchers with MockSugar { } it should "build duplicate blob paths in the same filesystem" in { - val builder = makeBlobPathBuilder(endpoint, container) + val builder = makeBlobPathBuilder() val evalPath = "/test/inputFile.txt" val testString = endpoint.value + "/" + container + evalPath val blobPath1: BlobPath = builder build testString getOrElse fail() @@ -173,7 +154,7 @@ class BlobPathBuilderSpec extends AnyFlatSpec with Matchers with MockSugar { } it should "resolve a path without duplicating container name" in { - val 
builder = makeBlobPathBuilder(endpoint, container) + val builder = makeBlobPathBuilder() val rootString = s"${endpoint.value}/${container.value}/cromwell-execution" val blobRoot: BlobPath = builder build rootString getOrElse fail() blobRoot.toAbsolutePath.pathAsString should equal ("https://centaurtesting.blob.core.windows.net/test-blob/cromwell-execution") @@ -182,7 +163,7 @@ class BlobPathBuilderSpec extends AnyFlatSpec with Matchers with MockSugar { } it should "correctly remove a prefix from the blob path" in { - val builder = makeBlobPathBuilder(endpoint, container) + val builder = makeBlobPathBuilder() val rootString = s"${endpoint.value}/${container.value}/cromwell-execution/" val execDirString = s"${endpoint.value}/${container.value}/cromwell-execution/abc123/myworkflow/task1/def4356/execution/" val fileString = s"${endpoint.value}/${container.value}/cromwell-execution/abc123/myworkflow/task1/def4356/execution/stdout" @@ -195,7 +176,7 @@ class BlobPathBuilderSpec extends AnyFlatSpec with Matchers with MockSugar { } it should "not change a path if it doesn't start with a prefix" in { - val builder = makeBlobPathBuilder(endpoint, container) + val builder = makeBlobPathBuilder() val otherRootString = s"${endpoint.value}/${container.value}/foobar/" val fileString = s"${endpoint.value}/${container.value}/cromwell-execution/abc123/myworkflow/task1/def4356/execution/stdout" val otherBlobRoot: BlobPath = builder build otherRootString getOrElse fail() From f64b3b97b53cc976ef7e46ebea042da0362a2ef8 Mon Sep 17 00:00:00 2001 From: Justin Variath Thomas Date: Tue, 22 Aug 2023 13:19:48 -0400 Subject: [PATCH 11/25] WX-1174 Adjust NIO Copy functionality (#7207) Co-authored-by: Adam Nichols --- .../storage/blob/nio/AzureFileSystem.java | 11 ++++++++ .../blob/nio/AzureFileSystemProvider.java | 27 ++++++++++++++----- ...Repo template_ Cromwell server TES.run.xml | 3 ++- 3 files changed, 34 insertions(+), 7 deletions(-) diff --git a/azure-blob-nio/src/main/java/com/azure/storage/blob/nio/AzureFileSystem.java b/azure-blob-nio/src/main/java/com/azure/storage/blob/nio/AzureFileSystem.java index 381ed0289d7..862352b06ee 100644 --- a/azure-blob-nio/src/main/java/com/azure/storage/blob/nio/AzureFileSystem.java +++ b/azure-blob-nio/src/main/java/com/azure/storage/blob/nio/AzureFileSystem.java @@ -170,6 +170,8 @@ public final class AzureFileSystem extends FileSystem { private final Integer downloadResumeRetries; private FileStore defaultFileStore; private boolean closed; + + private AzureSasCredential currentActiveSasCredential; private Instant expiry; AzureFileSystem(AzureFileSystemProvider parentFileSystemProvider, String endpoint, Map config) @@ -188,6 +190,7 @@ public final class AzureFileSystem extends FileSystem { this.putBlobThreshold = (Long) config.get(AZURE_STORAGE_PUT_BLOB_THRESHOLD); this.maxConcurrencyPerRequest = (Integer) config.get(AZURE_STORAGE_MAX_CONCURRENCY_PER_REQUEST); this.downloadResumeRetries = (Integer) config.get(AZURE_STORAGE_DOWNLOAD_RESUME_RETRIES); + this.currentActiveSasCredential = (AzureSasCredential) config.get(AZURE_STORAGE_SAS_TOKEN_CREDENTIAL); // Initialize and ensure access to FileStores. 
this.defaultFileStore = this.initializeFileStore(config); @@ -496,6 +499,13 @@ Integer getMaxConcurrencyPerRequest() { return this.maxConcurrencyPerRequest; } + public String createSASAppendedURL(String url) throws IllegalStateException { + if (Objects.isNull(currentActiveSasCredential)) { + throw new IllegalStateException("No current active SAS credential present"); + } + return url + "?" + currentActiveSasCredential.getSignature(); + } + public Optional getExpiry() { return Optional.ofNullable(expiry); } @@ -514,5 +524,6 @@ public boolean isExpired(Duration buffer) { return Optional.ofNullable(this.expiry) .map(e -> Instant.now().plus(buffer).isAfter(e)) .orElse(true); + } } diff --git a/azure-blob-nio/src/main/java/com/azure/storage/blob/nio/AzureFileSystemProvider.java b/azure-blob-nio/src/main/java/com/azure/storage/blob/nio/AzureFileSystemProvider.java index 6881341d218..2066acf89d5 100644 --- a/azure-blob-nio/src/main/java/com/azure/storage/blob/nio/AzureFileSystemProvider.java +++ b/azure-blob-nio/src/main/java/com/azure/storage/blob/nio/AzureFileSystemProvider.java @@ -47,6 +47,7 @@ import java.util.concurrent.ConcurrentMap; import java.util.function.Consumer; import java.util.function.Supplier; +import java.util.stream.Collectors; import com.azure.core.util.CoreUtils; import com.azure.core.util.logging.ClientLogger; @@ -695,16 +696,23 @@ public void copy(Path source, Path destination, CopyOption... copyOptions) throw // Remove accepted options as we find them. Anything left we don't support. boolean replaceExisting = false; List optionsList = new ArrayList<>(Arrays.asList(copyOptions)); - if (!optionsList.contains(StandardCopyOption.COPY_ATTRIBUTES)) { - throw LoggingUtility.logError(ClientLoggerHolder.LOGGER, new UnsupportedOperationException( - "StandardCopyOption.COPY_ATTRIBUTES must be specified as the service will always copy " - + "file attributes.")); +// NOTE: We're going to assume COPY_ATTRIBUTES as a default copy option (but can still be provided and handled safely) +// REPLACE_EXISTING must still be provided if you want to replace existing file + +// if (!optionsList.contains(StandardCopyOption.COPY_ATTRIBUTES)) { +// throw LoggingUtility.logError(ClientLoggerHolder.LOGGER, new UnsupportedOperationException( +// "StandardCopyOption.COPY_ATTRIBUTES must be specified as the service will always copy " +// + "file attributes.")); +// } + if(optionsList.contains(StandardCopyOption.COPY_ATTRIBUTES)) { + optionsList.remove(StandardCopyOption.COPY_ATTRIBUTES); } - optionsList.remove(StandardCopyOption.COPY_ATTRIBUTES); + if (optionsList.contains(StandardCopyOption.REPLACE_EXISTING)) { replaceExisting = true; optionsList.remove(StandardCopyOption.REPLACE_EXISTING); } + if (!optionsList.isEmpty()) { throw LoggingUtility.logError(ClientLoggerHolder.LOGGER, new UnsupportedOperationException("Unsupported copy option found. Only " @@ -760,9 +768,16 @@ public void copy(Path source, Path destination, CopyOption... copyOptions) throw customer scenarios and how many virtual directories they copy, it could be better to check the directory status first and then do a copy or createDir, which would always be two requests for all resource types. */ + try { + /* + Format the url by appending the SAS token as a param, otherwise the copy request will fail. 
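+                    (Illustratively, a source URL like "https://<account>.blob.core.windows.net/<container>/<blob>" would be
+                    rewritten to "https://<account>.blob.core.windows.net/<container>/<blob>?<sas-signature>".)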
+                    AzureFileSystem has been updated to handle url transformation via createSASAppendedURL()
+                    */
+                    AzureFileSystem afs = (AzureFileSystem) sourceRes.getPath().getFileSystem();
+                    String sasAppendedSourceUrl = afs.createSASAppendedURL(sourceRes.getBlobClient().getBlobUrl());
                     SyncPoller<BlobCopyInfo, Void> pollResponse =
-                        destinationRes.getBlobClient().beginCopy(sourceRes.getBlobClient().getBlobUrl(), null, null, null,
+                        destinationRes.getBlobClient().beginCopy(sasAppendedSourceUrl, null, null, null,
                             null, requestConditions, null);
                     pollResponse.waitForCompletion(Duration.ofSeconds(COPY_TIMEOUT_SECONDS));
                 } catch (BlobStorageException e) {
diff --git a/runConfigurations/Repo template_ Cromwell server TES.run.xml b/runConfigurations/Repo template_ Cromwell server TES.run.xml
index 9cb1818a004..32127027ecc 100644
--- a/runConfigurations/Repo template_ Cromwell server TES.run.xml
+++ b/runConfigurations/Repo template_ Cromwell server TES.run.xml
@@ -1,6 +1,6 @@
-
\ No newline at end of file
From aea7343d3b256fcb42741d96dff99d0ce1ac0945 Mon Sep 17 00:00:00 2001
From: Tom Wiseman
Date: Thu, 24 Aug 2023 14:43:45 -0400
Subject: [PATCH 12/25] [WX-1168] TES Log Paths (#7210)

---
 .../impl/tes/TesJobCachingActorHelper.scala     |  1 +
 .../cromwell/backend/impl/tes/TesJobPaths.scala | 14 +++++++++++---
 .../backend/impl/tes/TesJobPathsSpec.scala      |  5 +++++
 3 files changed, 17 insertions(+), 3 deletions(-)

diff --git a/supportedBackends/tes/src/main/scala/cromwell/backend/impl/tes/TesJobCachingActorHelper.scala b/supportedBackends/tes/src/main/scala/cromwell/backend/impl/tes/TesJobCachingActorHelper.scala
index d9be8f77b5d..6c4b495d362 100644
--- a/supportedBackends/tes/src/main/scala/cromwell/backend/impl/tes/TesJobCachingActorHelper.scala
+++ b/supportedBackends/tes/src/main/scala/cromwell/backend/impl/tes/TesJobCachingActorHelper.scala
@@ -19,5 +19,6 @@ trait TesJobCachingActorHelper extends StandardCachingActorHelper {
   lazy val tesConfiguration: TesConfiguration = initializationData.tesConfiguration
   lazy val runtimeAttributes = TesRuntimeAttributes(validatedRuntimeAttributes, jobDescriptor.runtimeAttributes, tesConfiguration)
+  override protected def nonStandardMetadata: Map[String, Any] = super.nonStandardMetadata ++ tesJobPaths.azureLogPathsForMetadata
 }
diff --git a/supportedBackends/tes/src/main/scala/cromwell/backend/impl/tes/TesJobPaths.scala b/supportedBackends/tes/src/main/scala/cromwell/backend/impl/tes/TesJobPaths.scala
index a624a16328a..ab7b2b27916 100644
--- a/supportedBackends/tes/src/main/scala/cromwell/backend/impl/tes/TesJobPaths.scala
+++ b/supportedBackends/tes/src/main/scala/cromwell/backend/impl/tes/TesJobPaths.scala
@@ -30,9 +30,10 @@ case class TesJobPaths private[tes] (override val workflowPaths: TesWorkflowPath
   val callInputsRoot = callRoot.resolve("inputs")

   /*
-   * tesTaskRoot: This is the root directory that TES will use for files related to this task.
-   * TES expects a path relative to the root of the storage container.
-   * We provide it to TES as a k/v pair where the key is "internal_path_prefix" and the value is the relative path string.
+   * tesTaskRoot: The Azure TES implementation allows us to specify a working directory that it should use for its own files.
+   * Once the task finishes, this directory will contain stderr.txt, stdout.txt, and some other misc files.
+   * Unlike other paths we provide, this one is expected as a path relative to the root of the storage container.
+   * We provide it as a k/v pair where the key is "internal_path_prefix" and the value is the relative path string.
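+   * (An illustrative value might look like "cromwell-executions/wf_hello/<workflow-id>/call-hello/tes_task".)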
   * This is not a standard TES feature, but rather related to the Azure TES implementation that Terra uses.
   * While passing it outside of Terra won't do any harm, we could consider making this optional and/or configurable.
   */
@@ -42,6 +43,13 @@ case class TesJobPaths private[tes] (override val workflowPaths: TesWorkflowPath
     case anyOtherPath: Path => anyOtherPath.pathAsString
   }

+  // Like above: Nothing should rely on these files existing, since only the Azure TES implementation will actually create them.
+  // Used to send the Azure TES log paths to the frontend.
+  val azureLogPathsForMetadata : Map[String, Any] = Map(
+    "tes_stdout" -> taskFullPath./("stdout.txt").pathAsString,
+    "tes_stderr" -> taskFullPath./("stderr.txt").pathAsString
+  )
+
   // Given an output path, return a path localized to the storage file system
   def storageOutput(path: String): String = {
     callExecutionRoot.resolve(path).toString
diff --git a/supportedBackends/tes/src/test/scala/cromwell/backend/impl/tes/TesJobPathsSpec.scala b/supportedBackends/tes/src/test/scala/cromwell/backend/impl/tes/TesJobPathsSpec.scala
index 177331445c9..36bdd2a5015 100644
--- a/supportedBackends/tes/src/test/scala/cromwell/backend/impl/tes/TesJobPathsSpec.scala
+++ b/supportedBackends/tes/src/test/scala/cromwell/backend/impl/tes/TesJobPathsSpec.scala
@@ -35,6 +35,11 @@ class TesJobPathsSpec extends AnyFlatSpec with CromwellTimeoutSpec with Matchers
       File(s"/cromwell-executions/wf_hello/$id/call-hello").pathAsString
     jobPaths.callExecutionDockerRoot.toString shouldBe
       File(s"/cromwell-executions/wf_hello/$id/call-hello/execution").pathAsString
+    jobPaths.azureLogPathsForMetadata shouldBe
+      Map(
+        "tes_stdout" -> File(s"local-cromwell-executions/wf_hello/$id/call-hello/tes_task/stdout.txt").pathAsString,
+        "tes_stderr" -> File(s"local-cromwell-executions/wf_hello/$id/call-hello/tes_task/stderr.txt").pathAsString
+      )

     val jobKeySharded = BackendJobDescriptorKey(call, Option(0), 1)
     val jobPathsSharded = TesJobPaths(jobKeySharded, wd, TesTestConfig.backendConfig)
From 33e991f6960ddd37c2a4ed14b4d610e275b8d90b Mon Sep 17 00:00:00 2001
From: Janet Gainer-Dewar
Date: Fri, 15 Sep 2023 11:50:07 -0400
Subject: [PATCH 13/25] WX-1264 Don't expire an unexpirable filesystem (#7216)

---
 .../storage/blob/nio/AzureFileSystem.java     |  7 ++-
 .../blob/BlobFileSystemManager.scala          | 19 +++----
 .../blob/AzureFileSystemSpec.scala            | 56 +++++++++++++++----
 .../blob/BlobPathBuilderFactorySpec.scala     |  7 ---
 4 files changed, 61 insertions(+), 28 deletions(-)

diff --git a/azure-blob-nio/src/main/java/com/azure/storage/blob/nio/AzureFileSystem.java b/azure-blob-nio/src/main/java/com/azure/storage/blob/nio/AzureFileSystem.java
index 862352b06ee..8ca4361bd3e 100644
--- a/azure-blob-nio/src/main/java/com/azure/storage/blob/nio/AzureFileSystem.java
+++ b/azure-blob-nio/src/main/java/com/azure/storage/blob/nio/AzureFileSystem.java
@@ -520,10 +520,15 @@ private void setExpiryFromSAS(AzureSasCredential token) {
         this.expiry = expiryString.map(es -> Instant.parse(es)).orElse(null);
     }

+    /**
+     * Return true if this filesystem has SAS credentials with an expiration date attached, and we're within
+     * `buffer` of the expiration. Return false if our credentials don't come with an expiration, or we
+     * aren't within `buffer` of our expiration.
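+     * For example, a credential expiring at 12:00 with a five-minute buffer is treated as expired once the
+     * current time is past 11:55.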
+ */ public boolean isExpired(Duration buffer) { return Optional.ofNullable(this.expiry) .map(e -> Instant.now().plus(buffer).isAfter(e)) - .orElse(true); + .orElse(false); } } diff --git a/filesystems/blob/src/main/scala/cromwell/filesystems/blob/BlobFileSystemManager.scala b/filesystems/blob/src/main/scala/cromwell/filesystems/blob/BlobFileSystemManager.scala index 6b6088c7689..e3de6783d85 100644 --- a/filesystems/blob/src/main/scala/cromwell/filesystems/blob/BlobFileSystemManager.scala +++ b/filesystems/blob/src/main/scala/cromwell/filesystems/blob/BlobFileSystemManager.scala @@ -13,7 +13,7 @@ import java.net.URI import java.nio.file._ import java.nio.file.spi.FileSystemProvider import java.time.temporal.ChronoUnit -import java.time.{Duration, Instant, OffsetDateTime} +import java.time.{Duration, OffsetDateTime} import java.util.UUID import scala.jdk.CollectionConverters._ import scala.util.{Failure, Success, Try} @@ -32,10 +32,6 @@ case class AzureFileSystemAPI(private val provider: FileSystemProvider = new Azu * See BlobSasTokenGenerator for more information on how a SAS token is generated */ object BlobFileSystemManager { - def parseTokenExpiry(token: AzureSasCredential): Option[Instant] = for { - expiryString <- token.getSignature.split("&").find(_.startsWith("se")).map(_.replaceFirst("se=","")).map(_.replace("%3A", ":")) - instant = Instant.parse(expiryString) - } yield instant def buildConfigMap(credential: AzureSasCredential, container: BlobContainerName): Map[String, Object] = { // Special handling is done here to provide a special key value pair if the placeholder token is provided @@ -226,9 +222,12 @@ case class NativeBlobSasTokenGenerator(subscription: Option[SubscriptionId] = No * * @return an AzureSasCredential for accessing a blob container */ - def generateBlobSasToken(endpoint: EndpointURL, container: BlobContainerName): Try[AzureSasCredential] = for { - bcc <- AzureUtils.buildContainerClientFromLocalEnvironment(container.toString, endpoint.toString, subscription.map(_.toString)) - bsssv = new BlobServiceSasSignatureValues(OffsetDateTime.now.plusDays(1), bcsp) - asc = new AzureSasCredential(bcc.generateSas(bsssv)) - } yield asc + def generateBlobSasToken(endpoint: EndpointURL, container: BlobContainerName): Try[AzureSasCredential] = { + val c = AzureUtils.buildContainerClientFromLocalEnvironment(container.toString, endpoint.toString, subscription.map(_.toString)) + + c.map { bcc => + val bsssv = new BlobServiceSasSignatureValues(OffsetDateTime.now.plusDays(1), bcsp) + new AzureSasCredential(bcc.generateSas(bsssv)) + }.orElse(Try(BlobFileSystemManager.PLACEHOLDER_TOKEN)) + } } diff --git a/filesystems/blob/src/test/scala/cromwell/filesystems/blob/AzureFileSystemSpec.scala b/filesystems/blob/src/test/scala/cromwell/filesystems/blob/AzureFileSystemSpec.scala index e0463bab740..9b8362ced80 100644 --- a/filesystems/blob/src/test/scala/cromwell/filesystems/blob/AzureFileSystemSpec.scala +++ b/filesystems/blob/src/test/scala/cromwell/filesystems/blob/AzureFileSystemSpec.scala @@ -1,25 +1,61 @@ package cromwell.filesystems.blob +import com.azure.core.credential.AzureSasCredential import com.azure.storage.blob.nio.{AzureFileSystem, AzureFileSystemProvider} import org.scalatest.flatspec.AnyFlatSpec import org.scalatest.matchers.should.Matchers -import java.time.Instant +import java.time.{Duration, Instant} +import java.time.temporal.ChronoUnit import scala.compat.java8.OptionConverters._ import scala.jdk.CollectionConverters._ class AzureFileSystemSpec extends AnyFlatSpec with 
Matchers { - val now = Instant.now() - val container = BlobContainerName("testConainer") - val exampleSas = BlobPathBuilderFactorySpec.buildExampleSasToken(now) - val exampleConfig = BlobFileSystemManager.buildConfigMap(exampleSas, container) - val exampleStorageEndpoint = BlobPathBuilderSpec.buildEndpoint("testStorageAccount") - val exampleCombinedEndpoint = BlobFileSystemManager.combinedEnpointContainerUri(exampleStorageEndpoint, container) - it should "parse an expiration from a sas token" in { + val fiveMinutes: Duration = Duration.of(5, ChronoUnit.MINUTES) + + private def makeFilesystemWithExpiration(expiration: Instant): AzureFileSystem = + makeFilesystemWithCreds(BlobPathBuilderFactorySpec.buildExampleSasToken(expiration)) + + private def makeFilesystemWithCreds(creds: AzureSasCredential): AzureFileSystem = { + val storageEndpoint = BlobPathBuilderSpec.buildEndpoint("testStorageAccount") + val container = BlobContainerName("testContainer") + val combinedEndpoint = BlobFileSystemManager.combinedEnpointContainerUri(storageEndpoint, container) + val provider = new AzureFileSystemProvider() - val filesystem : AzureFileSystem = provider.newFileSystem(exampleCombinedEndpoint, exampleConfig.asJava).asInstanceOf[AzureFileSystem] + provider.newFileSystem( + combinedEndpoint, + BlobFileSystemManager.buildConfigMap(creds, container).asJava + ).asInstanceOf[AzureFileSystem] + } + + it should "parse an expiration from a sas token" in { + val now = Instant.now() + val filesystem : AzureFileSystem = makeFilesystemWithExpiration(now) filesystem.getExpiry.asScala shouldBe Some(now) - filesystem.getFileStores.asScala.map(_.name()).exists(_ == container.value) shouldBe true + filesystem.getFileStores.asScala.map(_.name()).exists(_ == "testContainer") shouldBe true + } + + it should "not be expired when the token is fresh" in { + val anHourFromNow = Instant.now().plusSeconds(3600) + val filesystem : AzureFileSystem = makeFilesystemWithExpiration(anHourFromNow) + filesystem.isExpired(fiveMinutes) shouldBe false + } + + it should "be expired when we're within the buffer" in { + val threeMinutesFromNow = Instant.now().plusSeconds(180) + val filesystem : AzureFileSystem = makeFilesystemWithExpiration(threeMinutesFromNow) + filesystem.isExpired(fiveMinutes) shouldBe true + } + + it should "be expired when the token is stale" in { + val anHourAgo = Instant.now().minusSeconds(3600) + val filesystem : AzureFileSystem = makeFilesystemWithExpiration(anHourAgo) + filesystem.isExpired(fiveMinutes) shouldBe true + } + + it should "not be expired with public credentials" in { + val fileSystem = makeFilesystemWithCreds(BlobFileSystemManager.PLACEHOLDER_TOKEN) + fileSystem.isExpired(fiveMinutes) shouldBe false } } diff --git a/filesystems/blob/src/test/scala/cromwell/filesystems/blob/BlobPathBuilderFactorySpec.scala b/filesystems/blob/src/test/scala/cromwell/filesystems/blob/BlobPathBuilderFactorySpec.scala index c4ee102c58b..24783c15780 100644 --- a/filesystems/blob/src/test/scala/cromwell/filesystems/blob/BlobPathBuilderFactorySpec.scala +++ b/filesystems/blob/src/test/scala/cromwell/filesystems/blob/BlobPathBuilderFactorySpec.scala @@ -32,13 +32,6 @@ class BlobPathBuilderFactorySpec extends AnyFlatSpec with Matchers with MockSuga testToken.getSignature should equal(sourceToken) } - it should "parse an expiration time from a sas token" in { - val expiryTime = generateTokenExpiration(20L) - val sasToken = BlobPathBuilderFactorySpec.buildExampleSasToken(expiryTime) - val expiry = 
BlobFileSystemManager.parseTokenExpiry(sasToken) - expiry should contain(expiryTime) - } - it should "test that a filesystem gets closed correctly" in { val endpoint = BlobPathBuilderSpec.buildEndpoint("storageAccount") val container = BlobContainerName("test") From a290e6d513ca953a23a8e77096e1e6eeb367e00e Mon Sep 17 00:00:00 2001 From: Tom Wiseman Date: Fri, 15 Sep 2023 15:32:40 -0400 Subject: [PATCH 14/25] [WX-495] DRS Parallel Downloads (#7214) --- .../drs/localizer/DrsLocalizerMain.scala | 242 ++++++---- .../downloaders/AccessUrlDownloader.scala | 91 ---- .../downloaders/BulkAccessUrlDownloader.scala | 144 ++++++ .../downloaders/DownloaderFactory.scala | 8 +- .../downloaders/GcsUriDownloader.scala | 43 +- .../localizer/downloaders/GetmChecksum.scala | 2 +- .../drs/localizer/DrsLocalizerMainSpec.scala | 415 ++++++++---------- .../downloaders/AccessUrlDownloaderSpec.scala | 59 --- .../BulkAccessUrlDownloaderSpec.scala | 114 +++++ ...o template_ Cromwell DRS Localizer.run.xml | 2 +- 10 files changed, 653 insertions(+), 467 deletions(-) delete mode 100644 cromwell-drs-localizer/src/main/scala/drs/localizer/downloaders/AccessUrlDownloader.scala create mode 100644 cromwell-drs-localizer/src/main/scala/drs/localizer/downloaders/BulkAccessUrlDownloader.scala delete mode 100644 cromwell-drs-localizer/src/test/scala/drs/localizer/downloaders/AccessUrlDownloaderSpec.scala create mode 100644 cromwell-drs-localizer/src/test/scala/drs/localizer/downloaders/BulkAccessUrlDownloaderSpec.scala diff --git a/cromwell-drs-localizer/src/main/scala/drs/localizer/DrsLocalizerMain.scala b/cromwell-drs-localizer/src/main/scala/drs/localizer/DrsLocalizerMain.scala index 1858f395024..3d99538f614 100644 --- a/cromwell-drs-localizer/src/main/scala/drs/localizer/DrsLocalizerMain.scala +++ b/cromwell-drs-localizer/src/main/scala/drs/localizer/DrsLocalizerMain.scala @@ -2,13 +2,13 @@ package drs.localizer import cats.data.NonEmptyList import cats.effect.{ExitCode, IO, IOApp} -import cats.implicits._ +import cats.implicits.toTraverseOps import cloud.nio.impl.drs.DrsPathResolver.{FatalRetryDisposition, RegularRetryDisposition} import cloud.nio.impl.drs._ import cloud.nio.spi.{CloudNioBackoff, CloudNioSimpleExponentialBackoff} import com.typesafe.scalalogging.StrictLogging import drs.localizer.CommandLineParser.AccessTokenStrategy.{Azure, Google} -import drs.localizer.downloaders.AccessUrlDownloader.Hashes +import drs.localizer.DrsLocalizerMain.{defaultNumRetries, toValidatedUriType} import drs.localizer.downloaders._ import org.apache.commons.csv.{CSVFormat, CSVParser} @@ -17,7 +17,10 @@ import java.nio.charset.Charset import scala.concurrent.duration._ import scala.jdk.CollectionConverters._ import scala.language.postfixOps +import drs.localizer.URIType.URIType +case class UnresolvedDrsUrl(drsUrl: String, downloadDestinationPath: String) +case class ResolvedDrsUrl(drsResponse: DrsResolverResponse, downloadDestinationPath: String, uriType: URIType) object DrsLocalizerMain extends IOApp with StrictLogging { override def run(args: List[String]): IO[ExitCode] = { @@ -38,15 +41,18 @@ object DrsLocalizerMain extends IOApp with StrictLogging { def buildParser(): scopt.OptionParser[CommandLineArguments] = new CommandLineParser() + // Default retry parameters for resolving a DRS url + val defaultNumRetries: Int = 5 val defaultBackoff: CloudNioBackoff = CloudNioSimpleExponentialBackoff( - initialInterval = 10 seconds, maxInterval = 60 seconds, multiplier = 2) + initialInterval = 1 seconds, maxInterval = 60 seconds, multiplier = 
2) val defaultDownloaderFactory: DownloaderFactory = new DownloaderFactory { - override def buildAccessUrlDownloader(accessUrl: AccessUrl, downloadLoc: String, hashes: Hashes): IO[Downloader] = - IO.pure(AccessUrlDownloader(accessUrl, downloadLoc, hashes)) + override def buildGcsUriDownloader(gcsPath: String, serviceAccountJsonOption: Option[String], downloadLoc: String, requesterPaysProjectOption: Option[String]): Downloader = + GcsUriDownloader(gcsPath, serviceAccountJsonOption, downloadLoc, requesterPaysProjectOption) - override def buildGcsUriDownloader(gcsPath: String, serviceAccountJsonOption: Option[String], downloadLoc: String, requesterPaysProjectOption: Option[String]): IO[Downloader] = - IO.pure(GcsUriDownloader(gcsPath, serviceAccountJsonOption, downloadLoc, requesterPaysProjectOption)) + override def buildBulkAccessUrlDownloader(urlsToDownload: List[ResolvedDrsUrl]): Downloader = { + BulkAccessUrlDownloader(urlsToDownload) + } } private def printUsage: IO[ExitCode] = { @@ -54,35 +60,83 @@ object DrsLocalizerMain extends IOApp with StrictLogging { IO.pure(ExitCode.Error) } - def runLocalizer(commandLineArguments: CommandLineArguments, drsCredentials: DrsCredentials): IO[ExitCode] = { - commandLineArguments.manifestPath match { - case Some(manifestPath) => - val manifestFile = new File(manifestPath) - val csvParser = CSVParser.parse(manifestFile, Charset.defaultCharset(), CSVFormat.DEFAULT) - val exitCodes: IO[List[ExitCode]] = csvParser.asScala.map(record => { - val drsObject = record.get(0) - val containerPath = record.get(1) - localizeFile(commandLineArguments, drsCredentials, drsObject, containerPath) - }).toList.sequence - exitCodes.map(_.find(_ != ExitCode.Success).getOrElse(ExitCode.Success)) - case None => - val drsObject = commandLineArguments.drsObject.get - val containerPath = commandLineArguments.containerPath.get - localizeFile(commandLineArguments, drsCredentials, drsObject, containerPath) + /** + * Helper function to read a CSV file as pairs of drsURL -> local download destination. + * @param csvManifestPath Path to a CSV file where each row is something like: drs://asdf.ghj, path/to/my/directory + */ + def loadCSVManifest(csvManifestPath: String): IO[List[UnresolvedDrsUrl]] = { + IO { + val openFile = new File(csvManifestPath) + val csvParser = CSVParser.parse(openFile, Charset.defaultCharset(), CSVFormat.DEFAULT) + try{ + csvParser.getRecords.asScala.map(record => UnresolvedDrsUrl(record.get(0), record.get(1))).toList + } finally { + csvParser.close() + } } } - private def localizeFile(commandLineArguments: CommandLineArguments, drsCredentials: DrsCredentials, drsObject: String, containerPath: String) = { - new DrsLocalizerMain(drsObject, containerPath, drsCredentials, commandLineArguments.googleRequesterPaysProject). 
- resolveAndDownloadWithRetries(downloadRetries = 3, checksumRetries = 1, defaultDownloaderFactory, Option(defaultBackoff)).map(_.exitCode) + def runLocalizer(commandLineArguments: CommandLineArguments, drsCredentials: DrsCredentials) : IO[ExitCode] = { + val urlList = (commandLineArguments.manifestPath, commandLineArguments.drsObject, commandLineArguments.containerPath) match { + case (Some(manifestPath), _, _) => { + loadCSVManifest(manifestPath) + } + case (_, Some(drsObject), Some(containerPath)) => { + IO.pure(List(UnresolvedDrsUrl(drsObject, containerPath))) + } + case(_,_,_) => { + throw new RuntimeException("Illegal command line arguments supplied to drs localizer.") + } + } + val main = new DrsLocalizerMain(urlList, defaultDownloaderFactory, drsCredentials, commandLineArguments.googleRequesterPaysProject) + main.resolveAndDownload().map(_.exitCode) + } + + /** + * Helper function to decide which downloader to use based on data from the DRS response. + * Throws a runtime exception if the DRS response is invalid. + */ + def toValidatedUriType(accessUrl: Option[AccessUrl], gsUri: Option[String]): URIType = { + // if both are provided, prefer using access urls + (accessUrl, gsUri) match { + case (Some(_), _) => + if(!accessUrl.get.url.startsWith("https://")) { throw new RuntimeException("Resolved Access URL does not start with https://")} + URIType.ACCESS + case (_, Some(_)) => + if(!gsUri.get.startsWith("gs://")) { throw new RuntimeException("Resolved Google URL does not start with gs://")} + URIType.GCS + case (_, _) => + throw new RuntimeException("DRS response did not contain any URLs") + } } + } + +object URIType extends Enumeration { + type URIType = Value + val GCS, ACCESS, UNKNOWN = Value } -class DrsLocalizerMain(drsUrl: String, - downloadLoc: String, +class DrsLocalizerMain(toResolveAndDownload: IO[List[UnresolvedDrsUrl]], + downloaderFactory: DownloaderFactory, drsCredentials: DrsCredentials, requesterPaysProjectIdOption: Option[String]) extends StrictLogging { + /** + * This will: + * - resolve all URLS + * - build downloader(s) for them + * - Invoke the downloaders to localize the files. + * @return DownloadSuccess if all downloads succeed. An error otherwise. + */ + def resolveAndDownload(): IO[DownloadResult] = { + val downloadResults = buildDownloaders().flatMap { downloaderList => + downloaderList.map(downloader => downloader.download).traverse(identity) + } + downloadResults.map{list => + list.find(result => result != DownloadSuccess).getOrElse(DownloadSuccess) + } + } + def getDrsPathResolver: IO[DrsLocalizerDrsPathResolver] = { IO { val drsConfig = DrsConfig.fromEnv(sys.env) @@ -91,76 +145,86 @@ class DrsLocalizerMain(drsUrl: String, } } - def resolveAndDownloadWithRetries(downloadRetries: Int, - checksumRetries: Int, - downloaderFactory: DownloaderFactory, - backoff: Option[CloudNioBackoff], - downloadAttempt: Int = 0, - checksumAttempt: Int = 0): IO[DownloadResult] = { - - def maybeRetryForChecksumFailure(t: Throwable): IO[DownloadResult] = { - if (checksumAttempt < checksumRetries) { - backoff foreach { b => Thread.sleep(b.backoffMillis) } - logger.warn(s"Attempting retry $checksumAttempt of $checksumRetries checksum retries to download $drsUrl", t) - // In the event of a checksum failure reset the download attempt to zero. 
-        resolveAndDownloadWithRetries(downloadRetries, checksumRetries, downloaderFactory, backoff map { _.next }, 0, checksumAttempt + 1)
-      } else {
-        IO.raiseError(new RuntimeException(s"Exhausted $checksumRetries checksum retries to resolve, download and checksum $drsUrl", t))
-      }
+  /**
+    * After resolving all of the URLs, this sorts them into an "Access" or "GCS" bucket.
+    * All access URLs will be downloaded as a batch with a single bulk downloader.
+    * All Google URLs will be downloaded individually in their own Google downloader.
+    * @return List of all downloaders required to fulfill the request.
+    */
+  def buildDownloaders() : IO[List[Downloader]] = {
+    resolveUrls(toResolveAndDownload).map { pendingDownloads =>
+      val accessUrls = pendingDownloads.filter(url => url.uriType == URIType.ACCESS)
+      val googleUrls = pendingDownloads.filter(url => url.uriType == URIType.GCS)
+      val bulkDownloader: List[Downloader] = if (accessUrls.isEmpty) List() else List(buildBulkAccessUrlDownloader(accessUrls))
+      val googleDownloaders: List[Downloader] = if (googleUrls.isEmpty) List() else buildGoogleDownloaders(googleUrls)
+      bulkDownloader ++ googleDownloaders
     }
+  }

-    def maybeRetryForDownloadFailure(t: Throwable): IO[DownloadResult] = {
-      t match {
-        case _: FatalRetryDisposition =>
-          IO.raiseError(t)
-        case _ if downloadAttempt < downloadRetries =>
-          backoff foreach { b => Thread.sleep(b.backoffMillis) }
-          logger.warn(s"Attempting retry $downloadAttempt of $downloadRetries download retries to download $drsUrl", t)
-          resolveAndDownloadWithRetries(downloadRetries, checksumRetries, downloaderFactory, backoff map { _.next }, downloadAttempt + 1, checksumAttempt)
-        case _ =>
-          IO.raiseError(new RuntimeException(s"Exhausted $downloadRetries download retries to resolve, download and checksum $drsUrl", t))
-      }
+  def buildGoogleDownloaders(resolvedGoogleUrls: List[ResolvedDrsUrl]) : List[Downloader] = {
+    resolvedGoogleUrls.map { url =>
+      downloaderFactory.buildGcsUriDownloader(
+        gcsPath = url.drsResponse.gsUri.get,
+        serviceAccountJsonOption = url.drsResponse.googleServiceAccount.map(_.data.spaces2),
+        downloadLoc = url.downloadDestinationPath,
+        requesterPaysProjectOption = requesterPaysProjectIdOption)
     }
+  }
+  def buildBulkAccessUrlDownloader(resolvedUrls: List[ResolvedDrsUrl]) : Downloader = {
+    downloaderFactory.buildBulkAccessUrlDownloader(resolvedUrls)
+  }

-    resolveAndDownload(downloaderFactory).redeemWith({
-      maybeRetryForDownloadFailure
-    },
-    {
-      case f: FatalDownloadFailure =>
-        IO.raiseError(new RuntimeException(s"Fatal error downloading DRS object: $f"))
-      case r: RetryableDownloadFailure =>
-        maybeRetryForDownloadFailure(
-          new RuntimeException(s"Retryable download error: $r for $drsUrl on retry attempt $downloadAttempt of $downloadRetries") with RegularRetryDisposition)
-      case ChecksumFailure =>
-        maybeRetryForChecksumFailure(new RuntimeException(s"Checksum failure for $drsUrl on checksum retry attempt $checksumAttempt of $checksumRetries"))
-      case o => IO.pure(o)
-    })
+  /**
+    * Runs a synchronous HTTP request to resolve the provided DRS URL with the provided resolver.
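+    * The resolver is asked for the GsUri, GoogleServiceAccount, AccessUrl, and Hashes fields in a single request.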
+    */
+  def resolveSingleUrl(resolverObject: DrsLocalizerDrsPathResolver, drsUrlToResolve: UnresolvedDrsUrl): IO[ResolvedDrsUrl] = {
+    val fields = NonEmptyList.of(DrsResolverField.GsUri, DrsResolverField.GoogleServiceAccount, DrsResolverField.AccessUrl, DrsResolverField.Hashes)
+    val drsResponse = resolverObject.resolveDrs(drsUrlToResolve.drsUrl, fields)
+    drsResponse.map(resp => ResolvedDrsUrl(resp, drsUrlToResolve.downloadDestinationPath, toValidatedUriType(resp.accessUrl, resp.gsUri)))
   }

-  private [localizer] def resolveAndDownload(downloaderFactory: DownloaderFactory): IO[DownloadResult] = {
-    resolve(downloaderFactory) flatMap { _.download }
+
+  val defaultBackoff: CloudNioBackoff = CloudNioSimpleExponentialBackoff(
+    initialInterval = 10 seconds, maxInterval = 60 seconds, multiplier = 2)
+
+  /**
+    * Runs synchronous HTTP requests to resolve all the DRS URLs.
+    */
+  def resolveUrls(unresolvedUrls: IO[List[UnresolvedDrsUrl]]): IO[List[ResolvedDrsUrl]] = {
+    unresolvedUrls.flatMap { unresolvedList =>
+      getDrsPathResolver.flatMap { resolver =>
+        unresolvedList.map { unresolvedUrl =>
+          resolveWithRetries(resolver, unresolvedUrl, defaultNumRetries, Option(defaultBackoff))
+        }.traverse(identity)
+      }
+    }
   }

-  private [localizer] def resolve(downloaderFactory: DownloaderFactory): IO[Downloader] = {
-    val fields = NonEmptyList.of(DrsResolverField.GsUri, DrsResolverField.GoogleServiceAccount, DrsResolverField.AccessUrl, DrsResolverField.Hashes)
-    for {
-      resolver <- getDrsPathResolver
-      drsResolverResponse <- resolver.resolveDrs(drsUrl, fields)
-
-      // Currently DRS Resolver only supports resolving DRS paths to access URLs or GCS paths.
-      downloader <- (drsResolverResponse.accessUrl, drsResolverResponse.gsUri) match {
-        case (Some(accessUrl), _) =>
-          downloaderFactory.buildAccessUrlDownloader(accessUrl, downloadLoc, drsResolverResponse.hashes)
-        case (_, Some(gcsPath)) =>
-          val serviceAccountJsonOption = drsResolverResponse.googleServiceAccount.map(_.data.spaces2)
-          downloaderFactory.buildGcsUriDownloader(
-            gcsPath = gcsPath,
-            serviceAccountJsonOption = serviceAccountJsonOption,
-            downloadLoc = downloadLoc,
-            requesterPaysProjectOption = requesterPaysProjectIdOption)
-        case _ =>
-          IO.raiseError(new RuntimeException(DrsPathResolver.ExtractUriErrorMsg))
+  def resolveWithRetries(resolverObject: DrsLocalizerDrsPathResolver,
+                         drsUrlToResolve: UnresolvedDrsUrl,
+                         resolutionRetries: Int,
+                         backoff: Option[CloudNioBackoff],
+                         resolutionAttempt: Int = 0) : IO[ResolvedDrsUrl] = {
+
+    def maybeRetryForResolutionFailure(t: Throwable): IO[ResolvedDrsUrl] = {
+      if (resolutionAttempt < resolutionRetries) {
+        backoff foreach { b => Thread.sleep(b.backoffMillis) }
+        logger.warn(s"Attempting retry $resolutionAttempt of $resolutionRetries drs resolution retries to resolve ${drsUrlToResolve.drsUrl}", t)
+        resolveWithRetries(resolverObject, drsUrlToResolve, resolutionRetries, backoff map { _.next }, resolutionAttempt + 1)
+      } else {
+        IO.raiseError(new RuntimeException(s"Exhausted $resolutionRetries resolution retries to resolve ${drsUrlToResolve.drsUrl}", t))
+      }
-    } yield downloader
+    }
+
+    resolveSingleUrl(resolverObject, drsUrlToResolve).redeemWith(
+      recover = maybeRetryForResolutionFailure,
+      bind = {
+        case f: FatalRetryDisposition =>
+          IO.raiseError(new RuntimeException(s"Fatal error resolving DRS URL: $f"))
+        case _: RegularRetryDisposition =>
+          resolveWithRetries(resolverObject, drsUrlToResolve, resolutionRetries, backoff, resolutionAttempt + 1)
+        case o => IO.pure(o)
+      })
   }
 }
+
diff --git 
a/cromwell-drs-localizer/src/main/scala/drs/localizer/downloaders/AccessUrlDownloader.scala b/cromwell-drs-localizer/src/main/scala/drs/localizer/downloaders/AccessUrlDownloader.scala deleted file mode 100644 index ae6f2fa4f1e..00000000000 --- a/cromwell-drs-localizer/src/main/scala/drs/localizer/downloaders/AccessUrlDownloader.scala +++ /dev/null @@ -1,91 +0,0 @@ -package drs.localizer.downloaders - -import cats.data.Validated.{Invalid, Valid} -import cats.effect.{ExitCode, IO} -import cloud.nio.impl.drs.AccessUrl -import com.typesafe.scalalogging.StrictLogging -import common.exception.AggregatedMessageException -import common.util.StringUtil._ -import common.validation.ErrorOr.ErrorOr -import drs.localizer.downloaders.AccessUrlDownloader._ - -import scala.sys.process.{Process, ProcessLogger} -import scala.util.matching.Regex - -case class GetmResult(returnCode: Int, stderr: String) - -case class AccessUrlDownloader(accessUrl: AccessUrl, downloadLoc: String, hashes: Hashes) extends Downloader with StrictLogging { - def generateDownloadScript: ErrorOr[String] = { - val signedUrl = accessUrl.url - GetmChecksum(hashes, accessUrl).args map { checksumArgs => - s"""mkdir -p $$(dirname '$downloadLoc') && rm -f '$downloadLoc' && getm $checksumArgs --filepath '$downloadLoc' '$signedUrl'""" - } - } - - def runGetm: IO[GetmResult] = { - generateDownloadScript match { - case Invalid(errors) => - IO.raiseError(AggregatedMessageException("Error generating access URL download script", errors.toList)) - case Valid(script) => IO { - val copyCommand = Seq("bash", "-c", script) - val copyProcess = Process(copyCommand) - - val stderr = new StringBuilder() - val errorCapture: String => Unit = { s => stderr.append(s); () } - - // As of `getm` version 0.0.4 the contents of stdout do not appear to be interesting (only a progress bar - // with no option to suppress it), so ignore stdout for now. If stdout becomes interesting in future versions - // of `getm` it can be captured just like stderr is being captured here. - val returnCode = copyProcess ! ProcessLogger(_ => (), errorCapture) - - GetmResult(returnCode, stderr.toString().trim()) - } - } - } - - override def download: IO[DownloadResult] = { - // We don't want to log the unmasked signed URL here. On a PAPI backend this log will end up under the user's - // workspace bucket, but that bucket may have visibility different than the data referenced by the signed URL. 
-    val masked = accessUrl.url.maskSensitiveUri
-    logger.info(s"Attempting to download data to '$downloadLoc' from access URL '$masked'.")
-
-    runGetm map toDownloadResult
-  }
-
-  def toDownloadResult(getmResult: GetmResult): DownloadResult = {
-    getmResult match {
-      case GetmResult(0, stderr) if stderr.isEmpty =>
-        DownloadSuccess
-      case GetmResult(0, stderr) =>
-        stderr match {
-          case ChecksumFailureMessage() =>
-            ChecksumFailure
-          case _ =>
-            UnrecognizedRetryableDownloadFailure(ExitCode(0))
-        }
-      case GetmResult(rc, stderr) =>
-        stderr match {
-          case HttpStatusMessage(status) =>
-            Integer.parseInt(status) match {
-              case 408 | 429 =>
-                RecognizedRetryableDownloadFailure(ExitCode(rc))
-              case s if s / 100 == 4 =>
-                FatalDownloadFailure(ExitCode(rc))
-              case s if s / 100 == 5 =>
-                RecognizedRetryableDownloadFailure(ExitCode(rc))
-              case _ =>
-                UnrecognizedRetryableDownloadFailure(ExitCode(rc))
-            }
-          case _ =>
-            UnrecognizedRetryableDownloadFailure(ExitCode(rc))
-        }
-    }
-  }
-}
-
-object AccessUrlDownloader {
-  type Hashes = Option[Map[String, String]]
-
-  val ChecksumFailureMessage: Regex = raw""".*AssertionError: Checksum failed!.*""".r
-  val HttpStatusMessage: Regex = raw"""ERROR:getm\.cli.*"status_code":\s*(\d+).*""".r
-}
diff --git a/cromwell-drs-localizer/src/main/scala/drs/localizer/downloaders/BulkAccessUrlDownloader.scala b/cromwell-drs-localizer/src/main/scala/drs/localizer/downloaders/BulkAccessUrlDownloader.scala
new file mode 100644
index 00000000000..4668c5072ed
--- /dev/null
+++ b/cromwell-drs-localizer/src/main/scala/drs/localizer/downloaders/BulkAccessUrlDownloader.scala
@@ -0,0 +1,144 @@
+package drs.localizer.downloaders
+
+import cats.effect.{ExitCode, IO}
+import cloud.nio.impl.drs.{AccessUrl, DrsResolverResponse}
+import com.typesafe.scalalogging.StrictLogging
+
+import java.nio.charset.StandardCharsets
+import java.nio.file.{Files, Path, Paths}
+import scala.sys.process.{Process, ProcessLogger}
+import scala.util.matching.Regex
+import drs.localizer.ResolvedDrsUrl
+
+case class GetmResult(returnCode: Int, stderr: String)
+
+/**
+  * `getm` is a Python tool used to download resolved DRS URIs quickly and in parallel.
+  * This class builds a getm-manifest.json file to use as input, then builds and executes a shell command
+  * to invoke `getm`, which is expected to already be installed in the local environment.
+  * @param resolvedUrls the resolved DRS URLs (and their local download destinations) to fetch
+  */
+case class BulkAccessUrlDownloader(resolvedUrls: List[ResolvedDrsUrl]) extends Downloader with StrictLogging {
+
+  val getmManifestPath: Path = Paths.get("getm-manifest.json")
+
+  /**
+    * Write a JSON manifest to disk that looks like:
+    * [
+    *   {
+    *     "url" : "www.123.com",
+    *     "filepath" : "path/to/where/123/should/be/downloaded",
+    *     "checksum" : "sdfjndsfjkfsdjsdfkjsdf",
+    *     "checksum-algorithm" : "md5"
+    *   },
+    *   {
+    *     "url" : "www.567.com",
+    *     "filepath" : "path/to/where/567/should/be/downloaded",
+    *     "checksum" : "asdasdasfsdfsdfasdsdfasd",
+    *     "checksum-algorithm" : "md5"
+    *   }
+    * ]
+    *
+    * @param resolvedUrls the resolved DRS URLs to include in the manifest
+    * @return Filepath of a getm-manifest.json that getm can use to download multiple files in parallel.
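+    *
+    * Illustrative usage only (the `urls` value here is hypothetical):
+    * {{{
+    *   val urls: List[ResolvedDrsUrl] = ???
+    *   val manifest: IO[Path] = BulkAccessUrlDownloader(urls).generateJsonManifest(urls)
+    * }}}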
+    */
+  def generateJsonManifest(resolvedUrls: List[ResolvedDrsUrl]): IO[Path] = {
+    def toJsonString(drsResponse: DrsResolverResponse, destinationFilepath: String): String = {
+      // NB: the trailing comma on each entry is removed below, before the array is closed.
+      val accessUrl: AccessUrl = drsResponse.accessUrl.getOrElse(AccessUrl("missing", None))
+      drsResponse.hashes.map(_ => {
+        val getmChecksum = GetmChecksum(drsResponse.hashes, accessUrl)
+        val checksum = getmChecksum.value.getOrElse("error_calculating_checksum")
+        val checksumAlgorithm = getmChecksum.getmAlgorithm
+        s"""  {
+           |    "url" : "${accessUrl.url}",
+           |    "filepath" : "$destinationFilepath",
+           |    "checksum" : "$checksum",
+           |    "checksum-algorithm" : "$checksumAlgorithm"
+           |  },
+           |""".stripMargin
+      }).getOrElse(
+        s"""  {
+           |    "url" : "${accessUrl.url}",
+           |    "filepath" : "$destinationFilepath"
+           |  },
+           |""".stripMargin
+      )
+    }
+    IO {
+      var jsonString: String = "[\n"
+      for (resolvedUrl <- resolvedUrls) {
+        jsonString += toJsonString(resolvedUrl.drsResponse, resolvedUrl.downloadDestinationPath)
+      }
+      if (jsonString.contains(',')) {
+        // Remove the trailing comma from the last array element, but don't crash on an empty list.
+        jsonString = jsonString.substring(0, jsonString.lastIndexOf(","))
+      }
+      jsonString += "\n]"
+      Files.write(getmManifestPath, jsonString.getBytes(StandardCharsets.UTF_8))
+    }
+  }
+
+  def deleteJsonManifest(): Boolean = {
+    Files.deleteIfExists(getmManifestPath)
+  }
+
+  def generateGetmCommand(pathToManifestJson: Path): String = {
+    s"""getm --manifest ${pathToManifestJson.toString}"""
+  }
+
+  def runGetm: IO[GetmResult] = {
+    generateJsonManifest(resolvedUrls).flatMap { manifestPath =>
+      val script = generateGetmCommand(manifestPath)
+      val copyCommand: Seq[String] = Seq("bash", "-c", script)
+      logger.info(script)
+      val copyProcess = Process(copyCommand)
+      val stderr = new StringBuilder()
+      val errorCapture: String => Unit = { s => stderr.append(s); () }
+      val returnCode = copyProcess ! ProcessLogger(_ => (), errorCapture)
+      deleteJsonManifest()
+      logger.info(stderr.toString().trim())
+      IO(GetmResult(returnCode, stderr.toString().trim()))
+    }
+  }
+
+  override def download: IO[DownloadResult] = {
+    // We don't want to log the unmasked signed URL here. On a PAPI backend this log will end up under the user's
+    // workspace bucket, but that bucket may have visibility different than the data referenced by the signed URL.
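+    // Unlike the removed per-file AccessUrlDownloader, this logs only a generic message; the manifest contents
+    // (which include the signed URLs) are written to disk rather than logged.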
+ logger.info(s"Attempting to download data") + + runGetm map toDownloadResult + } + + def toDownloadResult(getmResult: GetmResult): DownloadResult = { + getmResult match { + case GetmResult(0, stderr) if stderr.isEmpty => + DownloadSuccess + case GetmResult(0, stderr) => + stderr match { + case BulkAccessUrlDownloader.ChecksumFailureMessage() => + ChecksumFailure + case _ => + UnrecognizedRetryableDownloadFailure(ExitCode(0)) + } + case GetmResult(rc, stderr) => + stderr match { + case BulkAccessUrlDownloader.HttpStatusMessage(status) => + Integer.parseInt(status) match { + case 408 | 429 => + RecognizedRetryableDownloadFailure(ExitCode(rc)) + case s if s / 100 == 4 => + FatalDownloadFailure(ExitCode(rc)) + case s if s / 100 == 5 => + RecognizedRetryableDownloadFailure(ExitCode(rc)) + case _ => + UnrecognizedRetryableDownloadFailure(ExitCode(rc)) + } + case _ => + UnrecognizedRetryableDownloadFailure(ExitCode(rc)) + } + } + } +} + +object BulkAccessUrlDownloader{ + type Hashes = Option[Map[String, String]] + + val ChecksumFailureMessage: Regex = raw""".*AssertionError: Checksum failed!.*""".r + val HttpStatusMessage: Regex = raw"""ERROR:getm\.cli.*"status_code":\s*(\d+).*""".r +} diff --git a/cromwell-drs-localizer/src/main/scala/drs/localizer/downloaders/DownloaderFactory.scala b/cromwell-drs-localizer/src/main/scala/drs/localizer/downloaders/DownloaderFactory.scala index 8465ede0dd6..6c7f27e8a6e 100644 --- a/cromwell-drs-localizer/src/main/scala/drs/localizer/downloaders/DownloaderFactory.scala +++ b/cromwell-drs-localizer/src/main/scala/drs/localizer/downloaders/DownloaderFactory.scala @@ -1,14 +1,12 @@ package drs.localizer.downloaders -import cats.effect.IO -import cloud.nio.impl.drs.AccessUrl -import drs.localizer.downloaders.AccessUrlDownloader.Hashes +import drs.localizer.ResolvedDrsUrl trait DownloaderFactory { - def buildAccessUrlDownloader(accessUrl: AccessUrl, downloadLoc: String, hashes: Hashes): IO[Downloader] + def buildBulkAccessUrlDownloader(urlsToDownload: List[ResolvedDrsUrl]) : Downloader def buildGcsUriDownloader(gcsPath: String, serviceAccountJsonOption: Option[String], downloadLoc: String, - requesterPaysProjectOption: Option[String]): IO[Downloader] + requesterPaysProjectOption: Option[String]): Downloader } diff --git a/cromwell-drs-localizer/src/main/scala/drs/localizer/downloaders/GcsUriDownloader.scala b/cromwell-drs-localizer/src/main/scala/drs/localizer/downloaders/GcsUriDownloader.scala index d4c81af6300..8991e79f5fd 100644 --- a/cromwell-drs-localizer/src/main/scala/drs/localizer/downloaders/GcsUriDownloader.scala +++ b/cromwell-drs-localizer/src/main/scala/drs/localizer/downloaders/GcsUriDownloader.scala @@ -1,10 +1,12 @@ package drs.localizer.downloaders import cats.effect.{ExitCode, IO} +import cloud.nio.spi.{CloudNioBackoff, CloudNioSimpleExponentialBackoff} import com.typesafe.scalalogging.StrictLogging import drs.localizer.downloaders.GcsUriDownloader.RequesterPaysErrorMsg - +import scala.language.postfixOps import java.nio.charset.StandardCharsets import java.nio.file.{Files, Path} +import scala.concurrent.duration.DurationInt import scala.sys.process.{Process, ProcessLogger} case class GcsUriDownloader(gcsUrl: String, @@ -12,7 +14,15 @@ case class GcsUriDownloader(gcsUrl: String, downloadLoc: String, requesterPaysProjectIdOption: Option[String]) extends Downloader with StrictLogging { + val defaultNumRetries: Int = 5 + val defaultBackoff: CloudNioBackoff = CloudNioSimpleExponentialBackoff( + initialInterval = 1 seconds, maxInterval = 60 seconds, 
multiplier = 2)
+
   override def download: IO[DownloadResult] = {
+    downloadWithRetries(defaultNumRetries, Option(defaultBackoff))
+  }
+
+  def runDownloadCommand: IO[DownloadResult] = {
     logger.info(s"Requester Pays project ID is $requesterPaysProjectIdOption")
     logger.info(s"Attempting to download $gcsUrl to $downloadLoc")
 
@@ -40,6 +50,37 @@ case class GcsUriDownloader(gcsUrl: String,
     IO.pure(result)
   }
 
+  def downloadWithRetries(downloadRetries: Int,
+                          backoff: Option[CloudNioBackoff],
+                          downloadAttempt: Int = 0): IO[DownloadResult] = {
+
+    def maybeRetryForDownloadFailure(t: Throwable): IO[DownloadResult] = {
+      if (downloadAttempt < downloadRetries) {
+        backoff foreach { b => Thread.sleep(b.backoffMillis) }
+        logger.warn(s"Attempting download retry $downloadAttempt of $downloadRetries for a GCS URL", t)
+        downloadWithRetries(downloadRetries, backoff map { _.next }, downloadAttempt + 1)
+      } else {
+        IO.raiseError(new RuntimeException(s"Exhausted $downloadRetries download retries for GCS file", t))
+      }
+    }
+
+    runDownloadCommand.redeemWith(
+      recover = maybeRetryForDownloadFailure,
+      bind = {
+        case s: DownloadSuccess.type =>
+          IO.pure(s)
+        case failure =>
+          // Route all failure results, recognized or not, through the same capped retry path as thrown
+          // exceptions so that a persistently failing download cannot retry forever.
+          maybeRetryForDownloadFailure(new RuntimeException(s"Failed to download $gcsUrl: $failure"))
+      })
+  }
+
   /**
    * Bash to download the GCS file using `gsutil`.
    */
diff --git a/cromwell-drs-localizer/src/main/scala/drs/localizer/downloaders/GetmChecksum.scala b/cromwell-drs-localizer/src/main/scala/drs/localizer/downloaders/GetmChecksum.scala
index 2a39a6543a3..2ca1bd3d2e3 100644
--- a/cromwell-drs-localizer/src/main/scala/drs/localizer/downloaders/GetmChecksum.scala
+++ b/cromwell-drs-localizer/src/main/scala/drs/localizer/downloaders/GetmChecksum.scala
@@ -3,7 +3,7 @@ package drs.localizer.downloaders
 import cats.syntax.validated._
 import cloud.nio.impl.drs.AccessUrl
 import common.validation.ErrorOr.ErrorOr
-import drs.localizer.downloaders.AccessUrlDownloader.Hashes
+import drs.localizer.downloaders.BulkAccessUrlDownloader.Hashes
 import org.apache.commons.codec.binary.Base64.encodeBase64String
 import org.apache.commons.codec.binary.Hex.decodeHex
 import org.apache.commons.text.StringEscapeUtils
diff --git a/cromwell-drs-localizer/src/test/scala/drs/localizer/DrsLocalizerMainSpec.scala b/cromwell-drs-localizer/src/test/scala/drs/localizer/DrsLocalizerMainSpec.scala
index 66799fcc099..52fa4c99330 100644
--- a/cromwell-drs-localizer/src/test/scala/drs/localizer/DrsLocalizerMainSpec.scala
+++ b/cromwell-drs-localizer/src/test/scala/drs/localizer/DrsLocalizerMainSpec.scala
@@ -3,12 +3,11 @@ package drs.localizer
 import cats.data.NonEmptyList
 import cats.effect.{ExitCode, IO}
 import cats.syntax.validated._
-import cloud.nio.impl.drs.DrsPathResolver.FatalRetryDisposition
+import drs.localizer.MockDrsPaths.{fakeAccessUrls, fakeDrsUrlWithGcsResolutionOnly, fakeGoogleUrls}
 import cloud.nio.impl.drs.{AccessUrl, DrsConfig, DrsCredentials, DrsResolverField, DrsResolverResponse}
 import common.assertion.CromwellTimeoutSpec
 import common.validation.ErrorOr.ErrorOr
 import drs.localizer.MockDrsLocalizerDrsPathResolver.{FakeAccessTokenStrategy, FakeHashes}
-import drs.localizer.downloaders.AccessUrlDownloader.Hashes
 import drs.localizer.downloaders._
 import org.scalatest.flatspec.AnyFlatSpec
 import org.scalatest.matchers.should.Matchers
@@ -19,6 +18,28
@@ class DrsLocalizerMainSpec extends AnyFlatSpec with CromwellTimeoutSpec with Mat val fakeDownloadLocation = "/root/foo/foo-123.bam" val fakeRequesterPaysId = "fake-billing-project" + val fakeGoogleInput : IO[List[UnresolvedDrsUrl]] = IO(List( + UnresolvedDrsUrl(fakeDrsUrlWithGcsResolutionOnly, "/path/to/nowhere") + )) + + val fakeAccessInput: IO[List[UnresolvedDrsUrl]] = IO(List( + UnresolvedDrsUrl("https://my-fake-access-url.com", "/path/to/somewhereelse") + )) + + val fakeBulkGoogleInput: IO[List[UnresolvedDrsUrl]] = IO(List( + UnresolvedDrsUrl("drs://my-fake-google-url.com", "/path/to/nowhere"), + UnresolvedDrsUrl("drs://my-fake-google-url.com2", "/path/to/nowhere2"), + UnresolvedDrsUrl("drs://my-fake-google-url.com3", "/path/to/nowhere3"), + UnresolvedDrsUrl("drs://my-fake-google-url.com4", "/path/to/nowhere4") + )) + + val fakeBulkAccessInput: IO[List[UnresolvedDrsUrl]] = IO(List( + UnresolvedDrsUrl("drs://my-fake-access-url.com", "/path/to/somewhereelse"), + UnresolvedDrsUrl("drs://my-fake-access-url2.com", "/path/to/somewhereelse2"), + UnresolvedDrsUrl("drs://my-fake-access-url3.com", "/path/to/somewhereelse3"), + UnresolvedDrsUrl("drs://my-fake-access-url4.com", "/path/to/somewhereelse4") + )) + behavior of "DrsLocalizerMain" it should "fail if drs input is not passed" in { @@ -29,264 +50,192 @@ class DrsLocalizerMainSpec extends AnyFlatSpec with CromwellTimeoutSpec with Mat DrsLocalizerMain.run(List(MockDrsPaths.fakeDrsUrlWithGcsResolutionOnly)).unsafeRunSync() shouldBe ExitCode.Error } - it should "accept arguments and run successfully without Requester Pays ID" in { - val mockDrsLocalizer = new MockDrsLocalizerMain(MockDrsPaths.fakeDrsUrlWithGcsResolutionOnly, fakeDownloadLocation, None) - val expected = GcsUriDownloader( - gcsUrl = "gs://abc/foo-123/abc123", - serviceAccountJson = None, - downloadLoc = fakeDownloadLocation, - requesterPaysProjectIdOption = None) - mockDrsLocalizer.resolve(DrsLocalizerMain.defaultDownloaderFactory).unsafeRunSync() shouldBe expected - } - - it should "run successfully with all 3 arguments" in { - val mockDrsLocalizer = new MockDrsLocalizerMain(MockDrsPaths.fakeDrsUrlWithGcsResolutionOnly, fakeDownloadLocation, Option(fakeRequesterPaysId)) - val expected = GcsUriDownloader( - gcsUrl = "gs://abc/foo-123/abc123", - serviceAccountJson = None, - downloadLoc = fakeDownloadLocation, - requesterPaysProjectIdOption = Option(fakeRequesterPaysId)) - mockDrsLocalizer.resolve(DrsLocalizerMain.defaultDownloaderFactory).unsafeRunSync() shouldBe expected - } - - it should "fail and throw error if the DRS Resolver response does not have gs:// url" in { - val mockDrsLocalizer = new MockDrsLocalizerMain(MockDrsPaths.fakeDrsUrlWithoutAnyResolution, fakeDownloadLocation, None) - - the[RuntimeException] thrownBy { - mockDrsLocalizer.resolve(DrsLocalizerMain.defaultDownloaderFactory).unsafeRunSync() - } should have message "No access URL nor GCS URI starting with 'gs://' found in the DRS Resolver response!" 
-  }
-
-  it should "resolve to use the correct downloader for an access url" in {
-    val mockDrsLocalizer = new MockDrsLocalizerMain(MockDrsPaths.fakeDrsUrlWithAccessUrlResolutionOnly, fakeDownloadLocation, None)
-    val expected = AccessUrlDownloader(
-      accessUrl = AccessUrl(url = "http://abc/def/ghi.bam", headers = None),
-      downloadLoc = fakeDownloadLocation,
-      hashes = FakeHashes
-    )
-    mockDrsLocalizer.resolve(DrsLocalizerMain.defaultDownloaderFactory).unsafeRunSync() shouldBe expected
-  }
-
-  it should "resolve to use the correct downloader for an access url when the DRS Resolver response also contains a gs url" in {
-    val mockDrsLocalizer = new MockDrsLocalizerMain(MockDrsPaths.fakeDrsUrlWithAccessUrlAndGcsResolution, fakeDownloadLocation, None)
-    val expected = AccessUrlDownloader(
-      accessUrl = AccessUrl(url = "http://abc/def/ghi.bam", headers = None), downloadLoc = fakeDownloadLocation,
-      hashes = FakeHashes
-    )
-    mockDrsLocalizer.resolve(DrsLocalizerMain.defaultDownloaderFactory).unsafeRunSync() shouldBe expected
-  }
-
-  it should "not retry on access URL download success" in {
-    var actualAttempts = 0
+  it should "tolerate no URLs being provided" in {
+    val mockDownloadFactory = new DownloaderFactory {
+      override def buildGcsUriDownloader(gcsPath: String, serviceAccountJsonOption: Option[String], downloadLoc: String, requesterPaysProjectOption: Option[String]): Downloader = {
+        // This test path should never ask for the Google downloader
+        throw new RuntimeException("test failure")
+      }
 
-    val drsLocalizer = new MockDrsLocalizerMain(MockDrsPaths.fakeDrsUrlWithAccessUrlResolutionOnly, fakeDownloadLocation, None) {
-      override def resolveAndDownload(downloaderFactory: DownloaderFactory): IO[DownloadResult] = {
-        actualAttempts = actualAttempts + 1
-        super.resolveAndDownload(downloaderFactory)
+      override def buildBulkAccessUrlDownloader(urlsToDownload: List[ResolvedDrsUrl]): Downloader = {
+        // This test path should never ask for the Bulk downloader
+        throw new RuntimeException("test failure")
       }
     }
-    val accessUrlDownloader = IO.pure(new Downloader {
-      override def download: IO[DownloadResult] =
-        IO.pure(DownloadSuccess)
-    })
+    val mockdrsLocalizer = new MockDrsLocalizerMain(IO(List()), mockDownloadFactory, FakeAccessTokenStrategy, Option(fakeRequesterPaysId))
+    val downloaders: List[Downloader] = mockdrsLocalizer.buildDownloaders().unsafeRunSync()
+    downloaders.length shouldBe 0
+  }
 
-    val downloaderFactory = new DownloaderFactory {
-      override def buildAccessUrlDownloader(accessUrl: AccessUrl, downloadLoc: String, hashes: Hashes): IO[Downloader] = {
-        accessUrlDownloader
+  it should "build correct downloader(s) for a single google URL" in {
+    val mockDownloadFactory = new DownloaderFactory {
+      override def buildGcsUriDownloader(gcsPath: String, serviceAccountJsonOption: Option[String], downloadLoc: String, requesterPaysProjectOption: Option[String]): Downloader = {
+        GcsUriDownloader(gcsPath, serviceAccountJsonOption, downloadLoc, requesterPaysProjectOption)
       }
 
-      override def buildGcsUriDownloader(gcsPath: String, serviceAccountJsonOption: Option[String], downloadLoc: String, requesterPaysProjectOption: Option[String]): IO[Downloader] = {
-        // This test path should never ask for the GCS downloader
-        throw new RuntimeException("test failure")
+      override def buildBulkAccessUrlDownloader(urlsToDownload: List[ResolvedDrsUrl]): Downloader = {
+        // This test path should never ask for the Bulk downloader
+        throw new RuntimeException("test failure")
       }
     }
-
drsLocalizer.resolveAndDownloadWithRetries( - downloadRetries = 3, - checksumRetries = 1, - downloaderFactory = downloaderFactory, - backoff = None - ).unsafeRunSync() shouldBe DownloadSuccess + val mockdrsLocalizer = new MockDrsLocalizerMain(IO(List(fakeGoogleUrls.head._1)), mockDownloadFactory,FakeAccessTokenStrategy, Option(fakeRequesterPaysId)) + val downloaders: List[Downloader] = mockdrsLocalizer.buildDownloaders().unsafeRunSync() + downloaders.length shouldBe 1 - actualAttempts shouldBe 1 - } - - it should "retry an appropriate number of times for regular retryable access URL download failures" in { - var actualAttempts = 0 - - val drsLocalizer = new MockDrsLocalizerMain(MockDrsPaths.fakeDrsUrlWithAccessUrlResolutionOnly, fakeDownloadLocation, None) { - override def resolveAndDownload(downloaderFactory: DownloaderFactory): IO[DownloadResult] = { - actualAttempts = actualAttempts + 1 - super.resolveAndDownload(downloaderFactory) - } + val correct = downloaders.head match { + case _: GcsUriDownloader => true + case _ => false } - val accessUrlDownloader = IO.pure(new Downloader { - override def download: IO[DownloadResult] = - IO.pure(RecognizedRetryableDownloadFailure(exitCode = ExitCode(0))) - }) - - val downloaderFactory = new DownloaderFactory { - override def buildAccessUrlDownloader(accessUrl: AccessUrl, downloadLoc: String, hashes: Hashes): IO[Downloader] = { - accessUrlDownloader - } + correct shouldBe true + } - override def buildGcsUriDownloader(gcsPath: String, serviceAccountJsonOption: Option[String], downloadLoc: String, requesterPaysProjectOption: Option[String]): IO[Downloader] = { + it should "build correct downloader(s) for a single access URL" in { + val mockDownloadFactory = new DownloaderFactory { + override def buildGcsUriDownloader(gcsPath: String, serviceAccountJsonOption: Option[String], downloadLoc: String, requesterPaysProjectOption: Option[String]): Downloader = { // This test path should never ask for the GCS downloader throw new RuntimeException("test failure") } - } - - assertThrows[Throwable] { - drsLocalizer.resolveAndDownloadWithRetries( - downloadRetries = 3, - checksumRetries = 1, - downloaderFactory = downloaderFactory, - backoff = None - ).unsafeRunSync() - } - - actualAttempts shouldBe 4 // 1 initial attempt + 3 retries = 4 total attempts - } - it should "retry an appropriate number of times for fatal retryable access URL download failures" in { - var actualAttempts = 0 - - val drsLocalizer = new MockDrsLocalizerMain(MockDrsPaths.fakeDrsUrlWithAccessUrlResolutionOnly, fakeDownloadLocation, None) { - override def resolveAndDownload(downloaderFactory: DownloaderFactory): IO[DownloadResult] = { - actualAttempts = actualAttempts + 1 - IO.raiseError(new RuntimeException("testing: fatal error") with FatalRetryDisposition) + override def buildBulkAccessUrlDownloader(urlsToDownload: List[ResolvedDrsUrl]): Downloader = { + BulkAccessUrlDownloader(urlsToDownload) } } - val accessUrlDownloader = IO.pure(new Downloader { - override def download: IO[DownloadResult] = - IO.pure(RecognizedRetryableDownloadFailure(exitCode = ExitCode(0))) - }) + val mockdrsLocalizer = new MockDrsLocalizerMain(IO(List(fakeAccessUrls.head._1)), mockDownloadFactory, FakeAccessTokenStrategy, Option(fakeRequesterPaysId)) + val downloaders: List[Downloader] = mockdrsLocalizer.buildDownloaders().unsafeRunSync() + downloaders.length shouldBe 1 + + val expected = BulkAccessUrlDownloader( + List(fakeAccessUrls.head._2) + ) + expected shouldEqual downloaders.head + } - val downloaderFactory 
= new DownloaderFactory {
-      override def buildAccessUrlDownloader(accessUrl: AccessUrl, downloadLoc: String, hashes: Hashes): IO[Downloader] = {
-        accessUrlDownloader
+  it should "build correct downloader(s) for multiple google URLs" in {
+    val mockDownloadFactory = new DownloaderFactory {
+      override def buildGcsUriDownloader(gcsPath: String, serviceAccountJsonOption: Option[String], downloadLoc: String, requesterPaysProjectOption: Option[String]): Downloader = {
+        GcsUriDownloader(gcsPath, serviceAccountJsonOption, downloadLoc, requesterPaysProjectOption)
       }
 
-      override def buildGcsUriDownloader(gcsPath: String, serviceAccountJsonOption: Option[String], downloadLoc: String, requesterPaysProjectOption: Option[String]): IO[Downloader] = {
-        // This test path should never ask for the GCS downloader
-        throw new RuntimeException("test failure")
+      override def buildBulkAccessUrlDownloader(urlsToDownload: List[ResolvedDrsUrl]): Downloader = {
+        // This test path should never ask for the Bulk downloader
+        throw new RuntimeException("test failure")
       }
     }
-
-    assertThrows[Throwable] {
-      drsLocalizer.resolveAndDownloadWithRetries(
-        downloadRetries = 3,
-        checksumRetries = 1,
-        downloaderFactory = downloaderFactory,
-        backoff = None
-      ).unsafeRunSync()
-    }
-
-    actualAttempts shouldBe 1 // 1 and done with a fatal exception
-  }
-
-  it should "not retry on GCS URI download success" in {
-    var actualAttempts = 0
-    val drsLocalizer = new MockDrsLocalizerMain(MockDrsPaths.fakeDrsUrlWithGcsResolutionOnly, fakeDownloadLocation, None) {
-      override def resolveAndDownload(downloaderFactory: DownloaderFactory): IO[DownloadResult] = {
-        actualAttempts = actualAttempts + 1
-        super.resolveAndDownload(downloaderFactory)
-      }
-    }
-    val gcsUriDownloader = IO.pure(new Downloader {
-      override def download: IO[DownloadResult] =
-        IO.pure(DownloadSuccess)
+    val unresolvedUrls: List[UnresolvedDrsUrl] = fakeGoogleUrls.map(pair => pair._1).toList
+    val mockdrsLocalizer = new MockDrsLocalizerMain(IO(unresolvedUrls), mockDownloadFactory, FakeAccessTokenStrategy, Option(fakeRequesterPaysId))
+    val downloaders: List[Downloader] = mockdrsLocalizer.buildDownloaders().unsafeRunSync()
+    downloaders.length shouldBe unresolvedUrls.length
+
+    val countGoogleDownloaders = downloaders.count(downloader => downloader match {
+      case _: GcsUriDownloader => true
+      case _ => false
     })
+    // We expect one GCS downloader for each GCS URI provided
+    countGoogleDownloaders shouldBe downloaders.length
+  }
 
-    val downloaderFactory = new DownloaderFactory {
-      override def buildAccessUrlDownloader(accessUrl: AccessUrl, downloadLoc: String, hashes: Hashes): IO[Downloader] = {
-        // This test path should never ask for the access URL downloader
+  it should "build a single bulk downloader for multiple access URLs" in {
+    val mockDownloadFactory = new DownloaderFactory {
+      override def buildGcsUriDownloader(gcsPath: String, serviceAccountJsonOption: Option[String], downloadLoc: String, requesterPaysProjectOption: Option[String]): Downloader = {
+        // This test path should never ask for the GCS downloader
         throw new RuntimeException("test failure")
       }
 
-      override def buildGcsUriDownloader(gcsPath: String, serviceAccountJsonOption: Option[String], downloadLoc: String, requesterPaysProjectOption: Option[String]): IO[Downloader] = {
-        gcsUriDownloader
+      override def buildBulkAccessUrlDownloader(urlsToDownload: List[ResolvedDrsUrl]): Downloader = {
+        BulkAccessUrlDownloader(urlsToDownload)
       }
     }
-
-    drsLocalizer.resolveAndDownloadWithRetries(
-      downloadRetries = 3,
-      checksumRetries = 1,
-      downloaderFactory = downloaderFactory,
-      backoff = None).unsafeRunSync()
-
-    actualAttempts
shouldBe 1
+    val unresolvedUrls: List[UnresolvedDrsUrl] = fakeAccessUrls.map(pair => pair._1).toList
+    val mockdrsLocalizer = new MockDrsLocalizerMain(IO(unresolvedUrls), mockDownloadFactory, FakeAccessTokenStrategy, Option(fakeRequesterPaysId))
+    val downloaders: List[Downloader] = mockdrsLocalizer.buildDownloaders().unsafeRunSync()
+    downloaders.length shouldBe 1
+
+    val countBulkDownloaders = downloaders.count(downloader => downloader match {
+      case _: BulkAccessUrlDownloader => true
+      case _ => false
+    })
+    // We expect one total Bulk downloader for all access URIs to share
+    countBulkDownloaders shouldBe 1
+    val expected = BulkAccessUrlDownloader(
+      fakeAccessUrls.map(pair => pair._2).toList
+    )
+    expected shouldEqual downloaders.head
   }
 
-  it should "retry an appropriate number of times for retryable GCS URI download failures" in {
-    var actualAttempts = 0
-    val drsLocalizer = new MockDrsLocalizerMain(MockDrsPaths.fakeDrsUrlWithGcsResolutionOnly, fakeDownloadLocation, None) {
-      override def resolveAndDownload(downloaderFactory: DownloaderFactory): IO[DownloadResult] = {
-        actualAttempts = actualAttempts + 1
-        super.resolveAndDownload(downloaderFactory)
-      }
-    }
-    val gcsUriDownloader = IO.pure(new Downloader {
-      override def download: IO[DownloadResult] =
-        IO.pure(RecognizedRetryableDownloadFailure(exitCode = ExitCode(1)))
-    })
+  it should "build 1 bulk downloader and 5 google downloaders for a mix of URLs" in {
+    val unresolvedUrls: List[UnresolvedDrsUrl] = fakeAccessUrls.map(pair => pair._1).toList ++ fakeGoogleUrls.map(pair => pair._1).toList
+    val mockdrsLocalizer = new MockDrsLocalizerMain(IO(unresolvedUrls), DrsLocalizerMain.defaultDownloaderFactory, FakeAccessTokenStrategy, Option(fakeRequesterPaysId))
+    val downloaders: List[Downloader] = mockdrsLocalizer.buildDownloaders().unsafeRunSync()
 
-    val downloaderFactory = new DownloaderFactory {
-      override def buildAccessUrlDownloader(accessUrl: AccessUrl, downloadLoc: String, hashes: Hashes): IO[Downloader] = {
-        // This test path should never ask for the access URL downloader
-        throw new RuntimeException("test failure")
-      }
+    downloaders.length shouldBe 6
 
-      override def buildGcsUriDownloader(gcsPath: String, serviceAccountJsonOption: Option[String], downloadLoc: String, requesterPaysProjectOption: Option[String]): IO[Downloader] = {
-        gcsUriDownloader
-      }
-    }
+    // We expect a single bulk downloader despite 5 access URLs being provided
+    val countBulkDownloaders = downloaders.count(downloader => downloader match {
+      case _: BulkAccessUrlDownloader => true
+      case _ => false
+    })
+    countBulkDownloaders shouldBe 1
+    val countGoogleDownloaders = downloaders.count(downloader => downloader match {
+      case _: GcsUriDownloader => true
+      case _ => false
+    })
+    // We expect one GCS downloader for each GCS URI provided
+    countGoogleDownloaders shouldBe 5
+  }
 
-    assertThrows[Throwable] {
-      drsLocalizer.resolveAndDownloadWithRetries(
-        downloadRetries = 3,
-        checksumRetries = 1,
-        downloaderFactory = downloaderFactory,
-        backoff = None).unsafeRunSync()
-    }
+  it should "accept arguments and run successfully without Requester Pays ID" in {
+    val unresolved = fakeGoogleUrls.head._1
+    val mockDrsLocalizer = new MockDrsLocalizerMain(IO(List(unresolved)), DrsLocalizerMain.defaultDownloaderFactory, FakeAccessTokenStrategy, None)
+    val expected = GcsUriDownloader(
+      gcsUrl = fakeGoogleUrls.get(unresolved).get.drsResponse.gsUri.get,
+      serviceAccountJson = None,
+
downloadLoc = unresolved.downloadDestinationPath, + requesterPaysProjectIdOption = None) + val downloader: Downloader = mockDrsLocalizer.buildDownloaders().unsafeRunSync().head + downloader shouldBe expected + } - actualAttempts shouldBe 4 // 1 initial attempt + 3 retries = 4 total attempts + it should "run successfully with all 3 arguments" in { + val unresolved = fakeGoogleUrls.head._1 + val mockDrsLocalizer = new MockDrsLocalizerMain(IO(List(unresolved)), DrsLocalizerMain.defaultDownloaderFactory, FakeAccessTokenStrategy, Option(fakeRequesterPaysId)) + val expected = GcsUriDownloader( + gcsUrl = fakeGoogleUrls.get(unresolved).get.drsResponse.gsUri.get, + serviceAccountJson = None, + downloadLoc = unresolved.downloadDestinationPath, + requesterPaysProjectIdOption = Option(fakeRequesterPaysId)) + val downloader: Downloader = mockDrsLocalizer.buildDownloaders().unsafeRunSync().head + downloader shouldBe expected } - it should "retry an appropriate number of times for checksum failures" in { - var actualAttempts = 0 - val drsLocalizer = new MockDrsLocalizerMain(MockDrsPaths.fakeDrsUrlWithAccessUrlResolutionOnly, fakeDownloadLocation, None) { - override def resolveAndDownload(downloaderFactory: DownloaderFactory): IO[DownloadResult] = { - actualAttempts = actualAttempts + 1 - super.resolveAndDownload(downloaderFactory) - } - } - val accessUrlDownloader = IO.pure(new Downloader { - override def download: IO[DownloadResult] = - IO.pure(ChecksumFailure) - }) + it should "successfully identify uri types, preferring access" in { + val exampleAccessResponse = DrsResolverResponse(accessUrl = Option(AccessUrl("https://something.com", FakeHashes))) + val exampleGoogleResponse = DrsResolverResponse(gsUri = Option("gs://something")) + val exampleMixedResponse = DrsResolverResponse(accessUrl = Option(AccessUrl("https://something.com", FakeHashes)), gsUri = Option("gs://something")) + DrsLocalizerMain.toValidatedUriType(exampleAccessResponse.accessUrl, exampleAccessResponse.gsUri) shouldBe URIType.ACCESS + DrsLocalizerMain.toValidatedUriType(exampleGoogleResponse.accessUrl, exampleGoogleResponse.gsUri) shouldBe URIType.GCS + DrsLocalizerMain.toValidatedUriType(exampleMixedResponse.accessUrl, exampleMixedResponse.gsUri) shouldBe URIType.ACCESS + } - val downloaderFactory = new DownloaderFactory { - override def buildAccessUrlDownloader(accessUrl: AccessUrl, downloadLoc: String, hashes: Hashes): IO[Downloader] = { - accessUrlDownloader - } + it should "throw an exception if the DRS Resolver response is invalid" in { + val badAccessResponse = DrsResolverResponse(accessUrl = Option(AccessUrl("hQQps://something.com", FakeHashes))) + val badGoogleResponse = DrsResolverResponse(gsUri = Option("gQQs://something")) + val emptyResponse = DrsResolverResponse() - override def buildGcsUriDownloader(gcsPath: String, serviceAccountJsonOption: Option[String], downloadLoc: String, requesterPaysProjectOption: Option[String]): IO[Downloader] = { - // This test path should never ask for the GCS URI downloader. 
- throw new RuntimeException("test failure") - } - } + the[RuntimeException] thrownBy { + DrsLocalizerMain.toValidatedUriType(badAccessResponse.accessUrl, badAccessResponse.gsUri) + } should have message "Resolved Access URL does not start with https://" - assertThrows[Throwable] { - drsLocalizer.resolveAndDownloadWithRetries( - downloadRetries = 3, - checksumRetries = 1, - downloaderFactory = downloaderFactory, - backoff = None).unsafeRunSync() - } + the[RuntimeException] thrownBy { + DrsLocalizerMain.toValidatedUriType(badGoogleResponse.accessUrl, badGoogleResponse.gsUri) + } should have message "Resolved Google URL does not start with gs://" - actualAttempts shouldBe 2 // 1 initial attempt + 1 retry = 2 total attempts + the[RuntimeException] thrownBy { + DrsLocalizerMain.toValidatedUriType(emptyResponse.accessUrl, emptyResponse.gsUri) + } should have message "DRS response did not contain any URLs" } } @@ -295,27 +244,53 @@ object MockDrsPaths { val fakeDrsUrlWithAccessUrlResolutionOnly = "drs://def/bar-456/def456" val fakeDrsUrlWithAccessUrlAndGcsResolution = "drs://ghi/baz-789/ghi789" val fakeDrsUrlWithoutAnyResolution = "drs://foo/bar/no-gcs-path" + + val fakeGoogleUrls: Map[UnresolvedDrsUrl, ResolvedDrsUrl] = Map( + (UnresolvedDrsUrl("drs://abc/foo-123/google/0", "/path/to/google/local0"), ResolvedDrsUrl(DrsResolverResponse(gsUri = Option("gs://some/uri0")), "/path/to/google/local0", URIType.GCS)), + (UnresolvedDrsUrl("drs://abc/foo-123/google/1", "/path/to/google/local1"), ResolvedDrsUrl(DrsResolverResponse(gsUri = Option("gs://some/uri1")), "/path/to/google/local1", URIType.GCS)), + (UnresolvedDrsUrl("drs://abc/foo-123/google/2", "/path/to/google/local2"), ResolvedDrsUrl(DrsResolverResponse(gsUri = Option("gs://some/uri2")), "/path/to/google/local2", URIType.GCS)), + (UnresolvedDrsUrl("drs://abc/foo-123/google/3", "/path/to/google/local3"), ResolvedDrsUrl(DrsResolverResponse(gsUri = Option("gs://some/uri3")), "/path/to/google/local3", URIType.GCS)), + (UnresolvedDrsUrl("drs://abc/foo-123/google/4", "/path/to/google/local4"), ResolvedDrsUrl(DrsResolverResponse(gsUri = Option("gs://some/uri4")), "/path/to/google/local4", URIType.GCS)) + ) + + val fakeAccessUrls: Map[UnresolvedDrsUrl, ResolvedDrsUrl] = Map( + (UnresolvedDrsUrl("drs://abc/foo-123/access/0", "/path/to/access/local0"), ResolvedDrsUrl(DrsResolverResponse(accessUrl = Option(AccessUrl("https://abc/foo-123/access/0", FakeHashes))), "/path/to/access/local0", URIType.ACCESS)), + (UnresolvedDrsUrl("drs://abc/foo-123/access/1", "/path/to/access/local1"), ResolvedDrsUrl(DrsResolverResponse(accessUrl = Option(AccessUrl("https://abc/foo-123/access/1", FakeHashes))), "/path/to/access/local1", URIType.ACCESS)), + (UnresolvedDrsUrl("drs://abc/foo-123/access/2", "/path/to/access/local2"), ResolvedDrsUrl(DrsResolverResponse(accessUrl = Option(AccessUrl("https://abc/foo-123/access/2", FakeHashes))), "/path/to/access/local2", URIType.ACCESS)), + (UnresolvedDrsUrl("drs://abc/foo-123/access/3", "/path/to/access/local3"), ResolvedDrsUrl(DrsResolverResponse(accessUrl = Option(AccessUrl("https://abc/foo-123/access/3", FakeHashes))), "/path/to/access/local3", URIType.ACCESS)), + (UnresolvedDrsUrl("drs://abc/foo-123/access/4", "/path/to/access/local4"), ResolvedDrsUrl(DrsResolverResponse(accessUrl = Option(AccessUrl("https://abc/foo-123/access/4", FakeHashes))), "/path/to/access/local4", URIType.ACCESS)) + ) } -class MockDrsLocalizerMain(drsUrl: String, - downloadLoc: String, - requesterPaysProjectIdOption: Option[String], +class 
MockDrsLocalizerMain(toResolveAndDownload: IO[List[UnresolvedDrsUrl]], + downloaderFactory: DownloaderFactory, + drsCredentials: DrsCredentials, + requesterPaysProjectIdOption: Option[String] ) - extends DrsLocalizerMain(drsUrl, downloadLoc, FakeAccessTokenStrategy, requesterPaysProjectIdOption) { + + extends DrsLocalizerMain(toResolveAndDownload, downloaderFactory, FakeAccessTokenStrategy, requesterPaysProjectIdOption) { override def getDrsPathResolver: IO[DrsLocalizerDrsPathResolver] = { IO { new MockDrsLocalizerDrsPathResolver(cloud.nio.impl.drs.MockDrsPaths.mockDrsConfig) } } + override def resolveSingleUrl(resolverObject: DrsLocalizerDrsPathResolver, drsUrlToResolve: UnresolvedDrsUrl): IO[ResolvedDrsUrl] = { + IO { + if (!fakeAccessUrls.contains(drsUrlToResolve) && !fakeGoogleUrls.contains(drsUrlToResolve)) { + throw new RuntimeException("Unexpected URI during testing") + } + fakeAccessUrls.getOrElse(drsUrlToResolve, fakeGoogleUrls.getOrElse(drsUrlToResolve, ResolvedDrsUrl(DrsResolverResponse(),"/12/3/", URIType.UNKNOWN))) + } + } } - class MockDrsLocalizerDrsPathResolver(drsConfig: DrsConfig) extends DrsLocalizerDrsPathResolver(drsConfig, FakeAccessTokenStrategy) { override def resolveDrs(drsPath: String, fields: NonEmptyList[DrsResolverField.Value]): IO[DrsResolverResponse] = { + val drsResolverResponse = DrsResolverResponse( size = Option(1234), hashes = FakeHashes diff --git a/cromwell-drs-localizer/src/test/scala/drs/localizer/downloaders/AccessUrlDownloaderSpec.scala b/cromwell-drs-localizer/src/test/scala/drs/localizer/downloaders/AccessUrlDownloaderSpec.scala deleted file mode 100644 index df7512dd81a..00000000000 --- a/cromwell-drs-localizer/src/test/scala/drs/localizer/downloaders/AccessUrlDownloaderSpec.scala +++ /dev/null @@ -1,59 +0,0 @@ -package drs.localizer.downloaders - -import cats.effect.ExitCode -import cats.syntax.validated._ -import cloud.nio.impl.drs.AccessUrl -import common.assertion.CromwellTimeoutSpec -import org.scalatest.flatspec.AnyFlatSpec -import org.scalatest.matchers.should.Matchers -import org.scalatest.prop.TableDrivenPropertyChecks._ - -class AccessUrlDownloaderSpec extends AnyFlatSpec with CromwellTimeoutSpec with Matchers { - it should "return the correct download script for a url-only access URL, no requester pays" in { - val fakeDownloadLocation = "/root/foo/foo-123.bam" - val fakeAccessUrl = "http://abc/def/ghi.bam" - - val downloader = AccessUrlDownloader( - accessUrl = AccessUrl(url = fakeAccessUrl, headers = None), - downloadLoc = fakeDownloadLocation, - hashes = None - ) - - val expected = s"""mkdir -p $$(dirname '$fakeDownloadLocation') && rm -f '$fakeDownloadLocation' && getm --checksum-algorithm 'null' --checksum null --filepath '$fakeDownloadLocation' '$fakeAccessUrl'""".validNel - - downloader.generateDownloadScript shouldBe expected - } - - { - val results = Table( - ("exitCode", "stderr", "download result"), - (0, "", DownloadSuccess), - // In `getm` version 0.0.4 checksum failures currently exit 0. - (0, "oh me oh my: AssertionError: Checksum failed!!!", ChecksumFailure), - // Unrecognized because of non-zero exit code without an HTTP status, despite what looks like a checksum failure. - (1, "oh me oh my: AssertionError: Checksum failed!!!", UnrecognizedRetryableDownloadFailure(ExitCode(1))), - // Unrecognized because of zero exit status with stderr that does not look like a checksum failure. - (0, "what the", UnrecognizedRetryableDownloadFailure(ExitCode(0))), - // Unrecognized because of non-zero exit code without an HTTP status. 
- (1, " foobar ", UnrecognizedRetryableDownloadFailure(ExitCode(1))), - // Unrecognized because of zero exit status with stderr that does not look like a checksum failure. - (0, """ERROR:getm.cli possibly some words "status_code": 503 words""", UnrecognizedRetryableDownloadFailure(ExitCode(0))), - // Recognized because of non-zero exit status and an HTTP status. - (1, """ERROR:getm.cli possibly some words "status_code": 503 words""", RecognizedRetryableDownloadFailure(ExitCode(1))), - // Recognized because of non-zero exit status and an HTTP status. - (1, """ERROR:getm.cli possibly some words "status_code": 408 more words""", RecognizedRetryableDownloadFailure(ExitCode(1))), - // Recognized and non-retryable because of non-zero exit status and 404 HTTP status. - (1, """ERROR:getm.cli possibly some words "status_code": 404 even more words""", FatalDownloadFailure(ExitCode(1))), - // Unrecognized because of zero exit status and 404 HTTP status. - (0, """ERROR:getm.cli possibly some words "status_code": 404 even more words""", UnrecognizedRetryableDownloadFailure(ExitCode(0))), - ) - - val accessUrlDownloader = AccessUrlDownloader(null, null, null) - - forAll(results) { (exitCode, stderr, expected) => - it should s"produce $expected for exitCode $exitCode and stderr '$stderr'" in { - accessUrlDownloader.toDownloadResult(GetmResult(exitCode, stderr)) shouldBe expected - } - } - } -} diff --git a/cromwell-drs-localizer/src/test/scala/drs/localizer/downloaders/BulkAccessUrlDownloaderSpec.scala b/cromwell-drs-localizer/src/test/scala/drs/localizer/downloaders/BulkAccessUrlDownloaderSpec.scala new file mode 100644 index 00000000000..7b96ece8d0a --- /dev/null +++ b/cromwell-drs-localizer/src/test/scala/drs/localizer/downloaders/BulkAccessUrlDownloaderSpec.scala @@ -0,0 +1,114 @@ +package drs.localizer.downloaders + +import cats.effect.{ExitCode, IO} +import cloud.nio.impl.drs.{AccessUrl, DrsResolverResponse} +import common.assertion.CromwellTimeoutSpec +import org.scalatest.prop.TableDrivenPropertyChecks._ +import drs.localizer.{ResolvedDrsUrl, URIType} + +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.matchers.should.Matchers + +import java.nio.file.Path + +class BulkAccessUrlDownloaderSpec extends AnyFlatSpec with CromwellTimeoutSpec with Matchers { + val ex1 = ResolvedDrsUrl(DrsResolverResponse(accessUrl = Option(AccessUrl("https://my.fake/url123", None))), "path/to/local/download/dest", URIType.ACCESS) + val ex2 = ResolvedDrsUrl(DrsResolverResponse(accessUrl = Option(AccessUrl("https://my.fake/url1234", None))), "path/to/local/download/dest2", URIType.ACCESS) + val ex3 = ResolvedDrsUrl(DrsResolverResponse(accessUrl = Option(AccessUrl("https://my.fake/url1235", None))), "path/to/local/download/dest3", URIType.ACCESS) + val emptyList : List[ResolvedDrsUrl] = List() + val oneElement: List[ResolvedDrsUrl] = List(ex1) + val threeElements: List[ResolvedDrsUrl] = List(ex1, ex2, ex3) + + it should "correctly parse a collection of Access Urls into a manifest.json" in { + val expected: String = + s"""|[ + | { + | "url" : "https://my.fake/url123", + | "filepath" : "path/to/local/download/dest" + | }, + | { + | "url" : "https://my.fake/url1234", + | "filepath" : "path/to/local/download/dest2" + | }, + | { + | "url" : "https://my.fake/url1235", + | "filepath" : "path/to/local/download/dest3" + | } + |]""".stripMargin + + val downloader = BulkAccessUrlDownloader(threeElements) + + val filepath: IO[Path] = downloader.generateJsonManifest(threeElements) + val source = 
scala.io.Source.fromFile(filepath.unsafeRunSync().toString) + val lines = try source.mkString finally source.close() + lines shouldBe expected + } + + it should "properly construct empty JSON array from empty list." in { + val expected: String = + s"""|[ + | + |]""".stripMargin + + val downloader = BulkAccessUrlDownloader(emptyList) + val filepath: IO[Path] = downloader.generateJsonManifest(emptyList) + val source = scala.io.Source.fromFile(filepath.unsafeRunSync().toString) + val lines = try source.mkString finally source.close() + lines shouldBe expected + } + + it should "properly construct JSON array from single element list." in { + val expected: String = + s"""|[ + | { + | "url" : "https://my.fake/url123", + | "filepath" : "path/to/local/download/dest" + | } + |]""".stripMargin + + val downloader = BulkAccessUrlDownloader(oneElement) + val filepath: IO[Path] = downloader.generateJsonManifest(oneElement) + val source = scala.io.Source.fromFile(filepath.unsafeRunSync().toString) + val lines = try source.mkString finally source.close() + lines shouldBe expected + } + + it should "properly construct the invocation command" in { + val downloader = BulkAccessUrlDownloader(oneElement) + val filepath: Path = downloader.generateJsonManifest(threeElements).unsafeRunSync() + val expected = s"""getm --manifest ${filepath.toString}""" + downloader.generateGetmCommand(filepath) shouldBe expected + } + + { + val results = Table( + ("exitCode", "stderr", "download result"), + (0, "", DownloadSuccess), + // In `getm` version 0.0.4 checksum failures currently exit 0. + (0, "oh me oh my: AssertionError: Checksum failed!!!", ChecksumFailure), + // Unrecognized because of non-zero exit code without an HTTP status, despite what looks like a checksum failure. + (1, "oh me oh my: AssertionError: Checksum failed!!!", UnrecognizedRetryableDownloadFailure(ExitCode(1))), + // Unrecognized because of zero exit status with stderr that does not look like a checksum failure. + (0, "what the", UnrecognizedRetryableDownloadFailure(ExitCode(0))), + // Unrecognized because of non-zero exit code without an HTTP status. + (1, " foobar ", UnrecognizedRetryableDownloadFailure(ExitCode(1))), + // Unrecognized because of zero exit status with stderr that does not look like a checksum failure. + (0, """ERROR:getm.cli possibly some words "status_code": 503 words""", UnrecognizedRetryableDownloadFailure(ExitCode(0))), + // Recognized because of non-zero exit status and an HTTP status. + (1, """ERROR:getm.cli possibly some words "status_code": 503 words""", RecognizedRetryableDownloadFailure(ExitCode(1))), + // Recognized because of non-zero exit status and an HTTP status. + (1, """ERROR:getm.cli possibly some words "status_code": 408 more words""", RecognizedRetryableDownloadFailure(ExitCode(1))), + // Recognized and non-retryable because of non-zero exit status and 404 HTTP status. + (1, """ERROR:getm.cli possibly some words "status_code": 404 even more words""", FatalDownloadFailure(ExitCode(1))), + // Unrecognized because of zero exit status and 404 HTTP status. 
+ (0, """ERROR:getm.cli possibly some words "status_code": 404 even more words""", UnrecognizedRetryableDownloadFailure(ExitCode(0))), + ) + val bulkDownloader = BulkAccessUrlDownloader(null) + + forAll(results) { (exitCode, stderr, expected) => + it should s"produce $expected for exitCode $exitCode and stderr '$stderr'" in { + bulkDownloader.toDownloadResult(GetmResult(exitCode, stderr)) shouldBe expected + } + } + } +} diff --git a/runConfigurations/Repo template_ Cromwell DRS Localizer.run.xml b/runConfigurations/Repo template_ Cromwell DRS Localizer.run.xml index 91b36d9277e..12f01e0179e 100644 --- a/runConfigurations/Repo template_ Cromwell DRS Localizer.run.xml +++ b/runConfigurations/Repo template_ Cromwell DRS Localizer.run.xml @@ -5,7 +5,7 @@