From 14e4c938f41a889e3ce23d9d86e62c19c6718663 Mon Sep 17 00:00:00 2001 From: Miles Granger Date: Fri, 15 Mar 2024 11:51:15 +0100 Subject: [PATCH 1/4] Support auto-setting AWS credentials for storage options --- dask_deltatable/core.py | 15 ++++++--- dask_deltatable/utils.py | 70 +++++++++++++++++++++++++++++++++++++++- dask_deltatable/write.py | 13 +++++++- tests/test_acceptance.py | 4 ++- 4 files changed, 95 insertions(+), 7 deletions(-) diff --git a/dask_deltatable/core.py b/dask_deltatable/core.py index b89e6d8..1fdd955 100644 --- a/dask_deltatable/core.py +++ b/dask_deltatable/core.py @@ -18,7 +18,7 @@ from pyarrow import dataset as pa_ds from .types import Filters -from .utils import get_partition_filters +from .utils import get_partition_filters, maybe_set_aws_credentials if Version(pa.__version__) >= Version("10.0.0"): filters_to_expression = pq.filters_to_expression @@ -94,6 +94,9 @@ def _read_from_filesystem( """ Reads the list of parquet files in parallel """ + storage_options = maybe_set_aws_credentials(path, storage_options) # type: ignore + delta_storage_options = maybe_set_aws_credentials(path, delta_storage_options) # type: ignore + fs, fs_token, _ = get_fs_token_paths(path, storage_options=storage_options) dt = DeltaTable( table_uri=path, version=version, storage_options=delta_storage_options @@ -116,12 +119,14 @@ def _read_from_filesystem( if columns: meta = meta[columns] + kws = dict(meta=meta, label="read-delta-table") + if not dd._dask_expr_enabled(): + # Setting token not supported in dask-expr + kws["token"] = tokenize(path, fs_token, **kwargs) # type: ignore return dd.from_map( partial(_read_delta_partition, fs=fs, columns=columns, schema=schema, **kwargs), pq_files, - meta=meta, - label="read-delta-table", - token=tokenize(path, fs_token, **kwargs), + **kws, ) @@ -270,6 +275,8 @@ def read_deltalake( else: if path is None: raise ValueError("Please Provide Delta Table path") + + delta_storage_options = maybe_set_aws_credentials(path, delta_storage_options) # type: ignore resultdf = _read_from_filesystem( path=path, version=version, diff --git a/dask_deltatable/utils.py b/dask_deltatable/utils.py index 3901f63..dabb6b4 100644 --- a/dask_deltatable/utils.py +++ b/dask_deltatable/utils.py @@ -1,10 +1,78 @@ from __future__ import annotations -from typing import cast +from typing import Any, cast from .types import Filter, Filters +def get_bucket_region(path: str): + import boto3 + + if not path.startswith("s3://"): + raise ValueError(f"'{path}' is not an S3 path") + bucket = path.replace("s3://", "").split("/")[0] + resp = boto3.client("s3").get_bucket_location(Bucket=bucket) + # Buckets in region 'us-east-1' results in None, b/c why not. + # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/s3/client/get_bucket_location.html#S3.Client.get_bucket_location + return resp["LocationConstraint"] or "us-east-1" + + +def maybe_set_aws_credentials(path: Any, options: dict[str, Any]) -> dict[str, Any]: + """ + Maybe set AWS credentials into ``options`` if existing AWS specific keys + not found in it and path is s3:// format. + + Parameters + ---------- + path : Any + If it's a string, we'll check if it starts with 's3://' then determine bucket + region if the AWS credentials should be set. + options : dict[str, Any] + Options, any kwargs to be supplied to things like S3FileSystem or similar + that may accept AWS credentials set. A copy is made and returned if modified. + + Returns + ------- + dict + Either the original options if not modified, or a copied and updated options + with AWS credentials inserted. + """ + + is_s3_path = getattr(path, "startswith", lambda _: False)("s3://") + if not is_s3_path: + return options + + # Avoid overwriting already provided credentials + keys = ("AWS_ACCESS_KEY", "AWS_SECRET_ACCESS_KEY", "access_key", "secret_key") + if not any(k in (options or ()) for k in keys): + # defers installing boto3 upfront, xref _read_from_catalog + import boto3 + + session = boto3.session.Session() + credentials = session.get_credentials() + if credentials is None: + return options + region = get_bucket_region(path) + + options = (options or {}).copy() + options.update( + # Capitalized is used in delta specific API and lowercase is for S3FileSystem + dict( + # TODO: w/o this, we need to configure a LockClient which seems to require dynamodb. + AWS_S3_ALLOW_UNSAFE_RENAME="true", + AWS_SECRET_ACCESS_KEY=credentials.secret_key, + AWS_ACCESS_KEY_ID=credentials.access_key, + AWS_SESSION_TOKEN=credentials.token, + AWS_REGION=region, + secret_key=credentials.secret_key, + access_key=credentials.access_key, + token=credentials.token, + region=region, + ) + ) + return options + + def get_partition_filters( partition_columns: list[str], filters: Filters ) -> list[list[Filter]] | None: diff --git a/dask_deltatable/write.py b/dask_deltatable/write.py index 75eca45..add512d 100644 --- a/dask_deltatable/write.py +++ b/dask_deltatable/write.py @@ -15,8 +15,15 @@ from dask.dataframe.core import Scalar from dask.highlevelgraph import HighLevelGraph from deltalake import DeltaTable + +try: + from deltalake.writer import MAX_SUPPORTED_WRITER_VERSION # type: ignore +except ImportError: + from deltalake.writer import ( + MAX_SUPPORTED_PYARROW_WRITER_VERSION as MAX_SUPPORTED_WRITER_VERSION, + ) + from deltalake.writer import ( - MAX_SUPPORTED_WRITER_VERSION, PYARROW_MAJOR_VERSION, AddAction, DeltaJSONEncoder, @@ -31,6 +38,7 @@ from toolz.itertoolz import pluck from ._schema import pyarrow_to_deltalake, validate_compatible +from .utils import maybe_set_aws_credentials def to_deltalake( @@ -123,6 +131,7 @@ def to_deltalake( ------- dask.Scalar """ + storage_options = maybe_set_aws_credentials(table_or_uri, storage_options) # type: ignore table, table_uri = try_get_table_and_table_uri(table_or_uri, storage_options) # We need to write against the latest table version @@ -136,6 +145,7 @@ def to_deltalake( storage_options = table._storage_options or {} storage_options.update(storage_options or {}) + storage_options = maybe_set_aws_credentials(table_uri, storage_options) filesystem = pa_fs.PyFileSystem(DeltaStorageHandler(table_uri, storage_options)) if isinstance(partition_by, str): @@ -253,6 +263,7 @@ def _commit( schema = validate_compatible(schemas) assert schema if table is None: + storage_options = maybe_set_aws_credentials(table_uri, storage_options) write_deltalake_pyarrow( table_uri, schema, diff --git a/tests/test_acceptance.py b/tests/test_acceptance.py index 17e99ff..5d1b57c 100644 --- a/tests/test_acceptance.py +++ b/tests/test_acceptance.py @@ -50,7 +50,9 @@ def test_reader_all_primitive_types(): # Dask and delta go through different parquet parsers which read the # timestamp differently. This is likely a bug in arrow but the delta result # is "more correct". - expected_ddf["timestamp"] = expected_ddf["timestamp"].astype("datetime64[us]") + expected_ddf["timestamp"] = ( + expected_ddf["timestamp"].astype("datetime64[us]").dt.tz_localize("UTC") + ) assert_eq(actual_ddf, expected_ddf) From 5261ee012c36c9453e1505d0854d52ee3f92dac5 Mon Sep 17 00:00:00 2001 From: Miles Granger Date: Mon, 18 Mar 2024 10:38:11 +0100 Subject: [PATCH 2/4] Tests for get_bucket_region and maybe_set_aws_credentials --- tests/test_utils.py | 82 ++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 81 insertions(+), 1 deletion(-) diff --git a/tests/test_utils.py b/tests/test_utils.py index d8b49dd..bdfbea2 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,8 +1,15 @@ from __future__ import annotations +import pathlib +import unittest.mock as mock + import pytest -from dask_deltatable.utils import get_partition_filters +from dask_deltatable.utils import ( + get_bucket_region, + get_partition_filters, + maybe_set_aws_credentials, +) @pytest.mark.parametrize( @@ -31,3 +38,76 @@ def test_partition_filters(cols, filters, expected): # make sure it works with additional level of wrapping res = get_partition_filters(cols, filters) assert res == expected + + +@mock.patch("boto3.session.Session") +@mock.patch("dask_deltatable.utils.get_bucket_region") +@pytest.mark.parametrize( + "options", + ( + None, + dict(), + dict(AWS_ACCESS_KEY_ID="foo", AWS_SECRET_ACCESS_KEY="bar"), + dict(access_key="foo", secret_key="bar"), + ), +) +@pytest.mark.parametrize("path", ("s3://path", "/another/path", pathlib.Path("."))) +def test_maybe_set_aws_credentials( + mocked_get_bucket_region, + mocked_session, + options, + path, +): + mock_creds = mock.MagicMock() + type(mock_creds).token = mock.PropertyMock(return_value="token") + type(mock_creds).access_key = mock.PropertyMock(return_value="access-key") + type(mock_creds).secret_key = mock.PropertyMock(return_value="secret-key") + + def mock_get_credentials(): + return mock_creds + + session = mocked_session.return_value + session.get_credentials.side_effect = mock_get_credentials + + mocked_get_bucket_region.return_value = "foo-region" + + opts = maybe_set_aws_credentials(path, options) + + if options and not any(k in options for k in ("AWS_ACCESS_KEY_ID", "access_key")): + assert opts["AWS_ACCESS_KEY_ID"] == "access-key" + assert opts["AWS_SECRET_ACCESS_KEY"] == "secret-key" + assert opts["AWS_SESSION_TOKEN"] == "token" + assert opts["AWS_REGION"] == "foo-region" + + assert opts["access_key"] == "access-key" + assert opts["secret_key"] == "secret-key" + assert opts["token"] == "token" + assert opts["region"] == "foo-region" + + # Did not alter input options if credentials were supplied by user + elif options: + assert options == opts + + +@mock.patch("boto3.client") +@pytest.mark.parametrize("location", (None, "region-foo")) +@pytest.mark.parametrize( + "path,bucket", + (("s3://foo/bar", "foo"), ("s3://fizzbuzz", "fizzbuzz"), ("/not/s3", None)), +) +def test_get_bucket_region(mock_client, location, path, bucket): + mock_client = mock_client.return_value + mock_client.get_bucket_location.return_value = {"LocationConstraint": location} + + if not path.startswith("s3://"): + with pytest.raises(ValueError, match="is not an S3 path"): + get_bucket_region(path) + return + + region = get_bucket_region(path) + + # AWS returns None if bucket located in us-east-1... + location = location if location else "us-east-1" + assert region == location + + mock_client.get_bucket_location.assert_has_calls([mock.call(Bucket=bucket)]) From b58712ecfb1bf57cb29874e5af8bb32097acd6f3 Mon Sep 17 00:00:00 2001 From: Miles Granger Date: Mon, 18 Mar 2024 11:07:40 +0100 Subject: [PATCH 3/4] Tests to ensure to/read_deltalake call maybe_set_aws_credentails --- dask_deltatable/core.py | 12 +++++++----- dask_deltatable/write.py | 8 ++++---- tests/test_acceptance.py | 10 ++++++++++ tests/test_write.py | 13 +++++++++++++ 4 files changed, 34 insertions(+), 9 deletions(-) diff --git a/dask_deltatable/core.py b/dask_deltatable/core.py index 1fdd955..6da4404 100644 --- a/dask_deltatable/core.py +++ b/dask_deltatable/core.py @@ -17,8 +17,8 @@ from packaging.version import Version from pyarrow import dataset as pa_ds +from . import utils from .types import Filters -from .utils import get_partition_filters, maybe_set_aws_credentials if Version(pa.__version__) >= Version("10.0.0"): filters_to_expression = pq.filters_to_expression @@ -44,7 +44,9 @@ def _get_pq_files(dt: DeltaTable, filter: Filters = None) -> list[str]: list[str] List of files matching optional filter. """ - partition_filters = get_partition_filters(dt.metadata().partition_columns, filter) + partition_filters = utils.get_partition_filters( + dt.metadata().partition_columns, filter + ) if not partition_filters: # can't filter return sorted(dt.file_uris()) @@ -94,8 +96,8 @@ def _read_from_filesystem( """ Reads the list of parquet files in parallel """ - storage_options = maybe_set_aws_credentials(path, storage_options) # type: ignore - delta_storage_options = maybe_set_aws_credentials(path, delta_storage_options) # type: ignore + storage_options = utils.maybe_set_aws_credentials(path, storage_options) # type: ignore + delta_storage_options = utils.maybe_set_aws_credentials(path, delta_storage_options) # type: ignore fs, fs_token, _ = get_fs_token_paths(path, storage_options=storage_options) dt = DeltaTable( @@ -276,7 +278,7 @@ def read_deltalake( if path is None: raise ValueError("Please Provide Delta Table path") - delta_storage_options = maybe_set_aws_credentials(path, delta_storage_options) # type: ignore + delta_storage_options = utils.maybe_set_aws_credentials(path, delta_storage_options) # type: ignore resultdf = _read_from_filesystem( path=path, version=version, diff --git a/dask_deltatable/write.py b/dask_deltatable/write.py index add512d..68b8754 100644 --- a/dask_deltatable/write.py +++ b/dask_deltatable/write.py @@ -37,8 +37,8 @@ ) from toolz.itertoolz import pluck +from . import utils from ._schema import pyarrow_to_deltalake, validate_compatible -from .utils import maybe_set_aws_credentials def to_deltalake( @@ -131,7 +131,7 @@ def to_deltalake( ------- dask.Scalar """ - storage_options = maybe_set_aws_credentials(table_or_uri, storage_options) # type: ignore + storage_options = utils.maybe_set_aws_credentials(table_or_uri, storage_options) # type: ignore table, table_uri = try_get_table_and_table_uri(table_or_uri, storage_options) # We need to write against the latest table version @@ -145,7 +145,7 @@ def to_deltalake( storage_options = table._storage_options or {} storage_options.update(storage_options or {}) - storage_options = maybe_set_aws_credentials(table_uri, storage_options) + storage_options = utils.maybe_set_aws_credentials(table_uri, storage_options) filesystem = pa_fs.PyFileSystem(DeltaStorageHandler(table_uri, storage_options)) if isinstance(partition_by, str): @@ -263,7 +263,7 @@ def _commit( schema = validate_compatible(schemas) assert schema if table is None: - storage_options = maybe_set_aws_credentials(table_uri, storage_options) + storage_options = utils.maybe_set_aws_credentials(table_uri, storage_options) write_deltalake_pyarrow( table_uri, schema, diff --git a/tests/test_acceptance.py b/tests/test_acceptance.py index 5d1b57c..0394277 100644 --- a/tests/test_acceptance.py +++ b/tests/test_acceptance.py @@ -14,6 +14,7 @@ import os import shutil +import unittest.mock as mock from urllib.request import urlretrieve import dask.dataframe as dd @@ -42,6 +43,15 @@ def download_data(): assert os.path.exists(DATA_DIR) +@mock.patch("dask_deltatable.utils.maybe_set_aws_credentials") +def test_reader_check_aws_credentials(maybe_set_aws_credentials): + # The full functionality of maybe_set_aws_credentials tests in test_utils + # we only need to ensure it's called here when reading with a str path + maybe_set_aws_credentials.return_value = dict() + ddt.read_deltalake(f"{DATA_DIR}/all_primitive_types/delta") + maybe_set_aws_credentials.assert_called() + + def test_reader_all_primitive_types(): actual_ddf = ddt.read_deltalake(f"{DATA_DIR}/all_primitive_types/delta") expected_ddf = dd.read_parquet( diff --git a/tests/test_write.py b/tests/test_write.py index 8afefbe..3feae43 100644 --- a/tests/test_write.py +++ b/tests/test_write.py @@ -1,6 +1,7 @@ from __future__ import annotations import os +import unittest.mock as mock import dask.dataframe as dd import pandas as pd @@ -61,6 +62,18 @@ def test_roundtrip(tmpdir, with_index, freq, partition_freq): assert_eq(ddf_read, ddf_dask) +@mock.patch("dask_deltatable.utils.maybe_set_aws_credentials") +def test_writer_check_aws_credentials(maybe_set_aws_credentials, tmpdir): + # The full functionality of maybe_set_aws_credentials tests in test_utils + # we only need to ensure it's called here when writing with a str path + maybe_set_aws_credentials.return_value = dict() + + df = pd.DataFrame({"col1": range(10)}) + ddf = dd.from_pandas(df, npartitions=2) + to_deltalake(str(tmpdir), ddf) + maybe_set_aws_credentials.assert_called() + + @pytest.mark.parametrize("unit", ["s", "ms", "us", "ns"]) def test_datetime(tmpdir, unit): """Ensure we can write datetime with different resolutions, From 44b9a3b7fdc92bd29972e582647184189fbae8cd Mon Sep 17 00:00:00 2001 From: Miles Granger Date: Mon, 18 Mar 2024 12:05:44 +0100 Subject: [PATCH 4/4] Skipif boto3 not installed --- tests/test_utils.py | 35 +++++++++++++++++++---------------- 1 file changed, 19 insertions(+), 16 deletions(-) diff --git a/tests/test_utils.py b/tests/test_utils.py index bdfbea2..4743f23 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -40,7 +40,6 @@ def test_partition_filters(cols, filters, expected): assert res == expected -@mock.patch("boto3.session.Session") @mock.patch("dask_deltatable.utils.get_bucket_region") @pytest.mark.parametrize( "options", @@ -54,10 +53,13 @@ def test_partition_filters(cols, filters, expected): @pytest.mark.parametrize("path", ("s3://path", "/another/path", pathlib.Path("."))) def test_maybe_set_aws_credentials( mocked_get_bucket_region, - mocked_session, options, path, ): + pytest.importorskip("boto3") + + mocked_get_bucket_region.return_value = "foo-region" + mock_creds = mock.MagicMock() type(mock_creds).token = mock.PropertyMock(return_value="token") type(mock_creds).access_key = mock.PropertyMock(return_value="access-key") @@ -66,12 +68,11 @@ def test_maybe_set_aws_credentials( def mock_get_credentials(): return mock_creds - session = mocked_session.return_value - session.get_credentials.side_effect = mock_get_credentials + with mock.patch("boto3.session.Session") as mocked_session: + session = mocked_session.return_value + session.get_credentials.side_effect = mock_get_credentials - mocked_get_bucket_region.return_value = "foo-region" - - opts = maybe_set_aws_credentials(path, options) + opts = maybe_set_aws_credentials(path, options) if options and not any(k in options for k in ("AWS_ACCESS_KEY_ID", "access_key")): assert opts["AWS_ACCESS_KEY_ID"] == "access-key" @@ -89,22 +90,24 @@ def mock_get_credentials(): assert options == opts -@mock.patch("boto3.client") @pytest.mark.parametrize("location", (None, "region-foo")) @pytest.mark.parametrize( "path,bucket", (("s3://foo/bar", "foo"), ("s3://fizzbuzz", "fizzbuzz"), ("/not/s3", None)), ) -def test_get_bucket_region(mock_client, location, path, bucket): - mock_client = mock_client.return_value - mock_client.get_bucket_location.return_value = {"LocationConstraint": location} +def test_get_bucket_region(location, path, bucket): + pytest.importorskip("boto3") + + with mock.patch("boto3.client") as mock_client: + mock_client = mock_client.return_value + mock_client.get_bucket_location.return_value = {"LocationConstraint": location} - if not path.startswith("s3://"): - with pytest.raises(ValueError, match="is not an S3 path"): - get_bucket_region(path) - return + if not path.startswith("s3://"): + with pytest.raises(ValueError, match="is not an S3 path"): + get_bucket_region(path) + return - region = get_bucket_region(path) + region = get_bucket_region(path) # AWS returns None if bucket located in us-east-1... location = location if location else "us-east-1"