Skip to content

Commit

Permalink
GTC-2618 Allow appends of GPKG layers to table
Browse files Browse the repository at this point in the history
  • Loading branch information
manukala6 committed May 20, 2024
1 parent fcbf56a commit 218b11f
Show file tree
Hide file tree
Showing 7 changed files with 49 additions and 9 deletions.
6 changes: 2 additions & 4 deletions app/models/pydantic/creation_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,10 +225,8 @@ class VectorSourceCreationOptions(StrictBaseModel):
def validate_source_uri(cls, v, values, **kwargs):
    """Validate the number of input files in ``source_uri`` for the driver.

    CSV sources may list one or more input files.  Single-file vector
    drivers (GeoJSON, GeoJSON sequence, ESRIJSON, ESRI Shapefile) must
    list exactly one.  Other drivers (e.g. GPKG, FileGDB) are not
    constrained here.

    NOTE: the scraped diff interleaved the removed ``else`` branch with
    the added ``elif`` branch; this body is the reconstructed post-commit
    version.
    """
    # Drivers whose sources must consist of exactly one file.
    single_file_drivers = [
        VectorDrivers.esrijson,
        VectorDrivers.shp,
        VectorDrivers.geojson_seq,
        VectorDrivers.geojson,
    ]
    if values.get("source_driver") == VectorDrivers.csv:
        assert len(v) >= 1, "CSV sources require at least one input file"
    elif values.get("source_driver") in single_file_drivers:
        # Pydantic v1 converts AssertionError into a ValidationError.
        assert (
            len(v) == 1
        ), "GeoJSON and ESRI Shapefile vector sources require one and only one input file"
    return v


Expand Down
17 changes: 15 additions & 2 deletions app/models/pydantic/versions.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from typing import List, Optional, Tuple, Union

from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, validator

from ..enum.creation_options import VectorDrivers
from ..enum.versions import VersionStatus
from .base import BaseRecord, StrictBaseModel
from .creation_options import SourceCreationOptions
Expand Down Expand Up @@ -58,7 +59,19 @@ class VersionUpdateIn(StrictBaseModel):


class VersionAppendIn(StrictBaseModel):
    """Request payload for appending new source data to an existing version."""

    # One or more URIs pointing at the new source file(s) to append.
    source_uri: List[str]
    source_driver: VectorDrivers = Field(
        ..., description="Driver of source file. Must be an OGR driver"
    )
    layers: Optional[List[str]] = Field(
        None,
        description="List of layer names to append to version. "
        "If not set, all layers in source_uri will be appended.",
    )

    @validator("source_driver")
    def validate_source_driver(cls, v, values, **kwargs):
        """Only allow append for drivers the pipeline supports.

        BUG FIX: in pydantic v1, ``values`` contains only fields validated
        *before* this one — the field under validation arrives as ``v``.
        The original checked ``values.get("source_driver")``, which is
        always ``None`` here, so every append request failed with
        "Appends for None are not supported".  Validate ``v`` instead.
        """
        supported = [VectorDrivers.csv, VectorDrivers.gpkg, VectorDrivers.file_gdb]
        assert v in supported, "Appends for {} are not supported".format(v)
        return v


class VersionResponse(Response):
Expand Down
15 changes: 14 additions & 1 deletion app/routes/datasets/versions.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,13 +240,26 @@ async def append_to_version(
# For the background task, we only need the new source uri from the request
input_data = {"creation_options": deepcopy(default_asset.creation_options)}
input_data["creation_options"]["source_uri"] = request.source_uri
# If there are no existing layers, we can just use the new layers
if input_data["creation_options"].get("layers") is None:
input_data["creation_options"]["layers"] = request.layers
# Otherwise we append the new layers to the existing ones
elif request.layers is not None:
input_data["creation_options"]["layers"] += request.layers
else:
input_data["creation_options"]["layers"] = None
background_tasks.add_task(
append_default_asset, dataset, version, input_data, default_asset.asset_id
)

# We now want to append the new uris to the existing ones and update the asset
update_data = {"creation_options": deepcopy(default_asset.creation_options)}
update_data["creation_options"]["source_uri"] += request.source_uri
update_data["creation_options"]["source_uri"] += request.source_uri # ERROR: only one source_uri is allowed
if input_data["creation_options"].get("layers") is not None:
if update_data["creation_options"]["layers"] is not None:
update_data["creation_options"]["layers"] += request.layers
else:
update_data["creation_options"]["layers"] = request.layers
await assets.update_asset(default_asset.asset_id, **update_data)

version_orm: ORMVersion = await versions.get_version(dataset, version)
Expand Down
3 changes: 3 additions & 0 deletions tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@
SHP_NAME = "test.shp.zip"
SHP_PATH = os.path.join(os.path.dirname(__file__), "fixtures", SHP_NAME)

GPKG_NAME = "test.gpkg.zip"
GPKG_PATH = os.path.join(os.path.dirname(__file__), "fixtures", GPKG_NAME)

BUCKET = "test-bucket"
PORT = 9000

Expand Down
3 changes: 3 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@
GEOJSON_NAME2,
GEOJSON_PATH,
GEOJSON_PATH2,
GPKG_NAME,
GPKG_PATH,
PORT,
SHP_NAME,
SHP_PATH,
Expand Down Expand Up @@ -308,6 +310,7 @@ def copy_fixtures():
s3_client.upload_file(CSV2_PATH, BUCKET, CSV2_NAME)
s3_client.upload_file(TSV_PATH, BUCKET, TSV_NAME)
s3_client.upload_file(SHP_PATH, BUCKET, SHP_NAME)
s3_client.upload_file(GPKG_PATH, BUCKET, GPKG_NAME)
s3_client.upload_file(APPEND_TSV_PATH, BUCKET, APPEND_TSV_NAME)

# upload a separate for each row so we can test running large numbers of sources in parallel
Expand Down
Binary file added tests/fixtures/test.gpkg.zip
Binary file not shown.
14 changes: 12 additions & 2 deletions tests/routes/datasets/test_versions.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from app.settings.globals import S3_ENTRYPOINT_URL
from app.utils.aws import get_s3_client
from tests import BUCKET, DATA_LAKE_BUCKET, SHP_NAME
from tests import BUCKET, DATA_LAKE_BUCKET, SHP_NAME, GPKG_NAME
from tests.conftest import FAKE_INT_DATA_PARAMS
from tests.tasks import MockCloudfrontClient
from tests.utils import (
Expand Down Expand Up @@ -314,7 +314,8 @@ async def test_invalid_source_uri(async_client: AsyncClient):
)

# Create a version with a valid source_uri so we have something to append to
version_payload["creation_options"]["source_uri"] = [f"s3://{BUCKET}/{SHP_NAME}"]
version_payload["creation_options"]["source_uri"] = [f"s3://{BUCKET}/{GPKG_NAME}"]
version_payload["creation_options"]["layers"] = ["layer1"]
await create_default_asset(
dataset,
version,
Expand Down Expand Up @@ -348,6 +349,15 @@ async def test_invalid_source_uri(async_client: AsyncClient):
== f"Version with name {dataset}.{bad_version} does not exist"
)

# Test appending to a version with missing layers
response = await async_client.post(
f"/dataset/{dataset}/{version}/append", json={"source_uri": f"s3://{BUCKET}/{GPKG_NAME}", "layers": ["layer3"]}
)
assert response.status_code == 400
assert response.json()["status"] == "failed"




@pytest.mark.asyncio
async def test_put_latest(async_client: AsyncClient):
Expand Down

0 comments on commit 218b11f

Please sign in to comment.