From 5e6527c1bccfe9a1ffe230789fa9bf6c6be62ea1 Mon Sep 17 00:00:00 2001 From: Fokko Driesprong Date: Mon, 8 May 2023 22:08:25 +0200 Subject: [PATCH 1/8] Convert information into dict --- .../Under the Hood-20230508-222313.yaml | 6 + dbt/adapters/spark/impl.py | 177 +++++++++++------- dbt/adapters/spark/relation.py | 6 +- 3 files changed, 119 insertions(+), 70 deletions(-) create mode 100644 .changes/unreleased/Under the Hood-20230508-222313.yaml diff --git a/.changes/unreleased/Under the Hood-20230508-222313.yaml b/.changes/unreleased/Under the Hood-20230508-222313.yaml new file mode 100644 index 000000000..29a628119 --- /dev/null +++ b/.changes/unreleased/Under the Hood-20230508-222313.yaml @@ -0,0 +1,6 @@ +kind: Under the Hood +body: Convert information into dict +time: 2023-05-08T22:23:13.704302+02:00 +custom: + Author: Fokko + Issue: "751" diff --git a/dbt/adapters/spark/impl.py b/dbt/adapters/spark/impl.py index 37de188c5..15d8fbe1d 100644 --- a/dbt/adapters/spark/impl.py +++ b/dbt/adapters/spark/impl.py @@ -58,6 +58,14 @@ class SparkConfig(AdapterConfig): merge_update_columns: Optional[str] = None +@dataclass(frozen=True) +class RelationInfo: + table_schema: str + table_name: str + columns: List[Tuple[str, str]] + properties: Dict[str, str] + + class SparkAdapter(SQLAdapter): COLUMN_NAMES = ( "table_database", @@ -79,9 +87,7 @@ class SparkAdapter(SQLAdapter): "stats:rows:description", "stats:rows:include", ) - INFORMATION_COLUMNS_REGEX = re.compile(r"^ \|-- (.*): (.*) \(nullable = (.*)\b", re.MULTILINE) - INFORMATION_OWNER_REGEX = re.compile(r"^Owner: (.*)$", re.MULTILINE) - INFORMATION_STATISTICS_REGEX = re.compile(r"^Statistics: (.*)$", re.MULTILINE) + INFORMATION_COLUMN_REGEX = re.compile(r" \|-- (.*): (.*) \(nullable = (.*)\)") HUDI_METADATA_COLUMNS = [ "_hoodie_commit_time", "_hoodie_commit_seqno", @@ -91,7 +97,6 @@ class SparkAdapter(SQLAdapter): ] Relation: TypeAlias = SparkRelation - RelationInfo = Tuple[str, str, str] Column: TypeAlias = SparkColumn ConnectionManager: TypeAlias = SparkConnectionManager AdapterSpecificConfigs: TypeAlias = SparkConfig @@ -139,13 +144,43 @@ def add_schema_to_cache(self, schema) -> str: def _get_relation_information(self, row: agate.Row) -> RelationInfo: """relation info was fetched with SHOW TABLES EXTENDED""" try: - _schema, name, _, information = row + # Example lines: + # Database: dbt_schema + # Table: names + # Owner: fokkodriesprong + # Created Time: Mon May 08 18:06:47 CEST 2023 + # Last Access: UNKNOWN + # Created By: Spark 3.3.2 + # Type: MANAGED + # Provider: hive + # Table Properties: [transient_lastDdlTime=1683562007] + # Statistics: 16 bytes + # Schema: root + # |-- idx: integer (nullable = false) + # |-- name: string (nullable = false) + table_properties = {} + columns = [] + _schema, name, _, information_blob = row + for line in information_blob.split("\n"): + if line: + if line.startswith(" |--"): + # A column + match = self.INFORMATION_COLUMN_REGEX.match(line) + if match: + columns.append((match[1], match[2])) + else: + logger.warning(f"Could not parse: {line}") + else: + # A property + parts = line.split(": ", maxsplit=2) + table_properties[parts[0]] = parts[1] + except ValueError: raise dbt.exceptions.DbtRuntimeError( f'Invalid value from "show tables extended ...", got {len(row)} values, expected 4' ) - return _schema, name, information + return RelationInfo(_schema, name, columns, table_properties) def _get_relation_information_using_describe(self, row: agate.Row) -> RelationInfo: """Relation info fetched using SHOW TABLES 
and an auxiliary DESCRIBE statement""" @@ -165,13 +200,42 @@ def _get_relation_information_using_describe(self, row: agate.Row) -> RelationIn logger.debug(f"Error while retrieving information about {table_name}: {e.msg}") table_results = AttrDict() - information = "" - for info_row in table_results: + # idx int + # name string + # + # # Partitioning + # Not partitioned + # + # # Metadata Columns + # _spec_id int + # _partition struct<> + # _file string + # _pos bigint + # _deleted boolean + # + # # Detailed Table Information + # Name sandbox.dbt_tabular3.names + # Location s3://tabular-wh-us-east-1/6efbcaf4-21ae-499d-b340-3bc1a7003f52/d2082e32-d2bd-4484-bb93-7bc445c1c6bb + # Provider iceberg + + # Wrap it in an iter, so we continue reading the properties from where we stopped reading columns + table_results_itr = iter(table_results) + + # First the columns + columns = [] + for info_row in table_results_itr: + if info_row[0] == "": + break + columns.append((info_row[0], info_row[1])) + + # Next all the properties + table_properties = {} + for info_row in table_results_itr: info_type, info_value, _ = info_row - if not info_type.startswith("#"): - information += f"{info_type}: {info_value}\n" + if not info_type.startswith("#") and info_type != "": + table_properties[info_type] = info_value - return _schema, name, information + return RelationInfo(_schema, name, columns, table_properties) def _build_spark_relation_list( self, @@ -179,27 +243,31 @@ def _build_spark_relation_list( relation_info_func: Callable[[agate.Row], RelationInfo], ) -> List[BaseRelation]: """Aggregate relations with format metadata included.""" - relations = [] + relations: List[BaseRelation] = [] for row in row_list: - _schema, name, information = relation_info_func(row) + relation = relation_info_func(row) rel_type: RelationType = ( - RelationType.View if "Type: VIEW" in information else RelationType.Table + RelationType.View + if relation.properties.get("type") == "VIEW" + else RelationType.Table ) - is_delta: bool = "Provider: delta" in information - is_hudi: bool = "Provider: hudi" in information - is_iceberg: bool = "Provider: iceberg" in information - - relation: BaseRelation = self.Relation.create( # type: ignore - schema=_schema, - identifier=name, - type=rel_type, - information=information, - is_delta=is_delta, - is_iceberg=is_iceberg, - is_hudi=is_hudi, + is_delta: bool = relation.properties.get("provider") == "delta" + is_hudi: bool = relation.properties.get("provider") == "hudi" + is_iceberg: bool = relation.properties.get("provider") == "iceberg" + + relations.append( + self.Relation.create( + schema=relation.table_schema, + identifier=relation.table_name, + type=rel_type, + is_delta=is_delta, + is_iceberg=is_iceberg, + is_hudi=is_hudi, + columns=relation.columns, + properties=relation.properties, + ) ) - relations.append(relation) return relations @@ -250,19 +318,10 @@ def get_relation(self, database: str, schema: str, identifier: str) -> Optional[ return super().get_relation(database, schema, identifier) def parse_describe_extended( - self, relation: BaseRelation, raw_rows: AttrDict + self, relation: SparkRelation, raw_rows: AttrDict ) -> List[SparkColumn]: # Convert the Row to a dict - dict_rows = [dict(zip(row._keys, row._values)) for row in raw_rows] - # Find the separator between the rows and the metadata provided - # by the DESCRIBE TABLE EXTENDED statement - pos = self.find_table_information_separator(dict_rows) - - # Remove rows that start with a hash, they are comments - rows = [row for row in 
raw_rows[0:pos] if not row["col_name"].startswith("#")] - metadata = {col["col_name"]: col["data_type"] for col in raw_rows[pos + 1 :]} - - raw_table_stats = metadata.get(KEY_TABLE_STATISTICS) + raw_table_stats = relation.properties.get(KEY_TABLE_STATISTICS) table_stats = SparkColumn.convert_table_stats(raw_table_stats) return [ SparkColumn( @@ -270,31 +329,22 @@ def parse_describe_extended( table_schema=relation.schema, table_name=relation.name, table_type=relation.type, - table_owner=str(metadata.get(KEY_TABLE_OWNER)), + table_owner=relation.properties.get(KEY_TABLE_OWNER, ""), table_stats=table_stats, - column=column["col_name"], + column=column_name, column_index=idx, - dtype=column["data_type"], + dtype=column_type, ) - for idx, column in enumerate(rows) + for idx, (column_name, column_type) in enumerate(relation.columns) ] - @staticmethod - def find_table_information_separator(rows: List[dict]) -> int: - pos = 0 - for row in rows: - if not row["col_name"] or row["col_name"].startswith("#"): - break - pos += 1 - return pos - def get_columns_in_relation(self, relation: BaseRelation) -> List[SparkColumn]: columns = [] try: rows: AttrDict = self.execute_macro( GET_COLUMNS_IN_RELATION_RAW_MACRO_NAME, kwargs={"relation": relation} ) - columns = self.parse_describe_extended(relation, rows) + columns = self.parse_describe_extended(relation, rows) # type: ignore except dbt.exceptions.DbtRuntimeError as e: # spark would throw error when table doesn't exist, where other # CDW would just return and empty list, normalizing the behavior here @@ -309,20 +359,13 @@ def get_columns_in_relation(self, relation: BaseRelation) -> List[SparkColumn]: columns = [x for x in columns if x.name not in self.HUDI_METADATA_COLUMNS] return columns - def parse_columns_from_information(self, relation: BaseRelation) -> List[SparkColumn]: - if hasattr(relation, "information"): - information = relation.information or "" - else: - information = "" - owner_match = re.findall(self.INFORMATION_OWNER_REGEX, information) - owner = owner_match[0] if owner_match else None - matches = re.finditer(self.INFORMATION_COLUMNS_REGEX, information) + def parse_columns_from_information(self, relation: SparkRelation) -> List[SparkColumn]: + owner = relation.properties.get(KEY_TABLE_OWNER, "") columns = [] - stats_match = re.findall(self.INFORMATION_STATISTICS_REGEX, information) - raw_table_stats = stats_match[0] if stats_match else None - table_stats = SparkColumn.convert_table_stats(raw_table_stats) - for match_num, match in enumerate(matches): - column_name, column_type, nullable = match.groups() + table_stats = SparkColumn.convert_table_stats( + relation.properties.get(KEY_TABLE_STATISTICS) + ) + for match_num, (column_name, column_type) in enumerate(relation.columns): column = SparkColumn( table_database=None, table_schema=relation.schema, @@ -338,7 +381,7 @@ def parse_columns_from_information(self, relation: BaseRelation) -> List[SparkCo return columns def _get_columns_for_catalog(self, relation: BaseRelation) -> Iterable[Dict[str, Any]]: - columns = self.parse_columns_from_information(relation) + columns = self.parse_columns_from_information(relation) # type: ignore for column in columns: # convert SparkColumns into catalog dicts @@ -362,7 +405,7 @@ def get_catalog(self, manifest): ) with executor(self.config) as tpe: - futures: List[Future[agate.Table]] = [] + futures: List[Future[agate.Table]] = [] # type: ignore for info, schemas in schema_map.items(): for schema in schemas: futures.append( diff --git 
a/dbt/adapters/spark/relation.py b/dbt/adapters/spark/relation.py index f5a3e3e15..164b41809 100644 --- a/dbt/adapters/spark/relation.py +++ b/dbt/adapters/spark/relation.py @@ -1,4 +1,4 @@ -from typing import Optional, TypeVar +from typing import Optional, TypeVar, List, Tuple, Dict from dataclasses import dataclass, field from dbt.adapters.base.relation import BaseRelation, Policy @@ -33,8 +33,8 @@ class SparkRelation(BaseRelation): is_delta: Optional[bool] = None is_hudi: Optional[bool] = None is_iceberg: Optional[bool] = None - # TODO: make this a dict everywhere - information: Optional[str] = None + columns: List[Tuple[str, str]] = field(default_factory=list) + properties: Dict[str, str] = field(default_factory=dict) def __post_init__(self): if self.database != self.schema and self.database: From 6cb42892e4a9fb56624b32ddfbb834a7916aa8e6 Mon Sep 17 00:00:00 2001 From: Fokko Driesprong Date: Tue, 9 May 2023 16:31:07 +0200 Subject: [PATCH 2/8] Fix and add a test --- dbt/adapters/spark/impl.py | 69 +++++++++------------ tests/unit/test_adapter.py | 121 ++++++++++++++++++++++++++++--------- 2 files changed, 124 insertions(+), 66 deletions(-) diff --git a/dbt/adapters/spark/impl.py b/dbt/adapters/spark/impl.py index ca88c70e3..05939c04e 100644 --- a/dbt/adapters/spark/impl.py +++ b/dbt/adapters/spark/impl.py @@ -87,7 +87,7 @@ class SparkAdapter(SQLAdapter): "stats:rows:description", "stats:rows:include", ) - INFORMATION_COLUMN_REGEX = re.compile(r" \|-- (.*): (.*) \(nullable = (.*)\)") + INFORMATION_COLUMN_REGEX = re.compile(r"[ | ]* \|-- (.*)\: (.*) \(nullable = (.*)\)") HUDI_METADATA_COLUMNS = [ "_hoodie_commit_time", "_hoodie_commit_seqno", @@ -163,17 +163,20 @@ def _get_relation_information(self, row: agate.Row) -> RelationInfo: _schema, name, _, information_blob = row for line in information_blob.split("\n"): if line: - if line.startswith(" |--"): + if " |--" in line: # A column match = self.INFORMATION_COLUMN_REGEX.match(line) if match: columns.append((match[1], match[2])) else: - logger.warning(f"Could not parse: {line}") + logger.warning(f"Could not parse column: {line}") else: # A property parts = line.split(": ", maxsplit=2) - table_properties[parts[0]] = parts[1] + if len(parts) == 2: + table_properties[parts[0]] = parts[1] + else: + logger.warning(f"Found invalid property: {line}") except ValueError: raise dbt.exceptions.DbtRuntimeError( @@ -182,6 +185,28 @@ def _get_relation_information(self, row: agate.Row) -> RelationInfo: return RelationInfo(_schema, name, columns, table_properties) + def _parse_describe_table( + self, table_results: agate.Table + ) -> Tuple[List[Tuple[str, str]], Dict[str, str]]: + # Wrap it in an iter, so we continue reading the properties from where we stopped reading columns + table_results_itr = iter(table_results) + + # First the columns + columns = [] + for info_row in table_results_itr: + if info_row[0] == "": + break + columns.append((info_row[0], info_row[1])) + + # Next all the properties + table_properties = {} + for info_row in table_results_itr: + info_type, info_value, _ = info_row + if not info_type.startswith("#") and info_type != "": + table_properties[info_type] = info_value + + return columns, table_properties + def _get_relation_information_using_describe(self, row: agate.Row) -> RelationInfo: """Relation info fetched using SHOW TABLES and an auxiliary DESCRIBE statement""" try: @@ -200,41 +225,7 @@ def _get_relation_information_using_describe(self, row: agate.Row) -> RelationIn logger.debug(f"Error while retrieving information about 
{table_name}: {e.msg}") table_results = AttrDict() - # idx int - # name string - # - # # Partitioning - # Not partitioned - # - # # Metadata Columns - # _spec_id int - # _partition struct<> - # _file string - # _pos bigint - # _deleted boolean - # - # # Detailed Table Information - # Name sandbox.dbt_tabular3.names - # Location s3://tabular-wh-us-east-1/6efbcaf4-21ae-499d-b340-3bc1a7003f52/d2082e32-d2bd-4484-bb93-7bc445c1c6bb - # Provider iceberg - - # Wrap it in an iter, so we continue reading the properties from where we stopped reading columns - table_results_itr = iter(table_results) - - # First the columns - columns = [] - for info_row in table_results_itr: - if info_row[0] == "": - break - columns.append((info_row[0], info_row[1])) - - # Next all the properties - table_properties = {} - for info_row in table_results_itr: - info_type, info_value, _ = info_row - if not info_type.startswith("#") and info_type != "": - table_properties[info_type] = info_value - + columns, table_properties = self._parse_describe_table(table_results) return RelationInfo(_schema, name, columns, table_properties) def _build_spark_relation_list( diff --git a/tests/unit/test_adapter.py b/tests/unit/test_adapter.py index 3c7fccd35..0b66e8ae6 100644 --- a/tests/unit/test_adapter.py +++ b/tests/unit/test_adapter.py @@ -1,6 +1,7 @@ import unittest from unittest import mock +import agate import dbt.flags as flags from dbt.exceptions import DbtRuntimeError from agate import Row @@ -550,19 +551,25 @@ def test_parse_columns_from_information_with_table_type_and_delta_provider(self) " |-- struct_col: struct (nullable = true)\n" " | |-- struct_inner_col: string (nullable = true)\n" ) - relation = SparkRelation.create( - schema="default_schema", identifier="mytable", type=rel_type, information=information + row = agate.MappedSequence(("default_schema", "mytable", False, information)) + config = self._get_target_http(self.project_cfg) + adapter = SparkAdapter(config) + + tables = adapter._build_spark_relation_list( + row_list=[row], + relation_info_func=adapter._get_relation_information, ) + self.assertEqual(len(tables), 1) - config = self._get_target_http(self.project_cfg) - columns = SparkAdapter(config).parse_columns_from_information(relation) - self.assertEqual(len(columns), 4) + columns = adapter.parse_describe_extended(tables[0], None) + + self.assertEqual(len(columns), 5) self.assertEqual( columns[0].to_column_dict(omit_none=False), { "table_database": None, - "table_schema": relation.schema, - "table_name": relation.name, + "table_schema": "default_schema", + "table_name": "mytable", "table_type": rel_type, "table_owner": "root", "column": "col1", @@ -582,8 +589,8 @@ def test_parse_columns_from_information_with_table_type_and_delta_provider(self) columns[3].to_column_dict(omit_none=False), { "table_database": None, - "table_schema": relation.schema, - "table_name": relation.name, + "table_schema": "default_schema", + "table_name": "mytable", "table_type": rel_type, "table_owner": "root", "column": "struct_col", @@ -635,19 +642,26 @@ def test_parse_columns_from_information_with_view_type(self): " |-- struct_col: struct (nullable = true)\n" " | |-- struct_inner_col: string (nullable = true)\n" ) - relation = SparkRelation.create( - schema="default_schema", identifier="myview", type=rel_type, information=information - ) + row = agate.MappedSequence(("default_schema", "myview", False, information)) config = self._get_target_http(self.project_cfg) - columns = SparkAdapter(config).parse_columns_from_information(relation) 
- self.assertEqual(len(columns), 4) + adapter = SparkAdapter(config) + + tables = adapter._build_spark_relation_list( + row_list=[row], + relation_info_func=adapter._get_relation_information, + ) + self.assertEqual(len(tables), 1) + + columns = adapter.parse_describe_extended(tables[0], None) + + self.assertEqual(len(columns), 5) self.assertEqual( columns[1].to_column_dict(omit_none=False), { "table_database": None, - "table_schema": relation.schema, - "table_name": relation.name, + "table_schema": "default_schema", + "table_name": "myview", "table_type": rel_type, "table_owner": "root", "column": "col2", @@ -663,8 +677,8 @@ def test_parse_columns_from_information_with_view_type(self): columns[3].to_column_dict(omit_none=False), { "table_database": None, - "table_schema": relation.schema, - "table_name": relation.name, + "table_schema": "default_schema", + "table_name": "myview", "table_type": rel_type, "table_owner": "root", "column": "struct_col", @@ -701,19 +715,26 @@ def test_parse_columns_from_information_with_table_type_and_parquet_provider(sel " |-- struct_col: struct (nullable = true)\n" " | |-- struct_inner_col: string (nullable = true)\n" ) - relation = SparkRelation.create( - schema="default_schema", identifier="mytable", type=rel_type, information=information - ) + row = agate.MappedSequence(("default_schema", "mytable", False, information)) config = self._get_target_http(self.project_cfg) - columns = SparkAdapter(config).parse_columns_from_information(relation) - self.assertEqual(len(columns), 4) + adapter = SparkAdapter(config) + + tables = adapter._build_spark_relation_list( + row_list=[row], + relation_info_func=adapter._get_relation_information, + ) + self.assertEqual(len(tables), 1) + + columns = adapter.parse_describe_extended(tables[0], None) + + self.assertEqual(len(columns), 5) self.assertEqual( columns[2].to_column_dict(omit_none=False), { "table_database": None, - "table_schema": relation.schema, - "table_name": relation.name, + "table_schema": "default_schema", + "table_name": "mytable", "table_type": rel_type, "table_owner": "root", "column": "dt", @@ -737,8 +758,8 @@ def test_parse_columns_from_information_with_table_type_and_parquet_provider(sel columns[3].to_column_dict(omit_none=False), { "table_database": None, - "table_schema": relation.schema, - "table_name": relation.name, + "table_schema": "default_schema", + "table_name": "mytable", "table_type": rel_type, "table_owner": "root", "column": "struct_col", @@ -757,3 +778,49 @@ def test_parse_columns_from_information_with_table_type_and_parquet_provider(sel "stats:rows:value": 12345678, }, ) + + def test_parse_columns_from_describe_extended(self): + self.maxDiff = None + rows = [ + agate.MappedSequence(["idx", "int", ""]), + agate.MappedSequence(["name", "string", ""]), + agate.MappedSequence(["", "", ""]), + agate.MappedSequence(["# Partitioning", "", ""]), + agate.MappedSequence(["Not partitioned", "", ""]), + agate.MappedSequence(["", "", ""]), + agate.MappedSequence(["# Metadata Columns", "", ""]), + agate.MappedSequence(["_spec_id", "int", ""]), + agate.MappedSequence(["_partition", "struct<>", ""]), + agate.MappedSequence(["_file", "string", ""]), + agate.MappedSequence(["_pos", "bigint", ""]), + agate.MappedSequence(["_deleted", "boolean", ""]), + agate.MappedSequence(["", "", ""]), + agate.MappedSequence(["# Detailed Table Information", "", ""]), + agate.MappedSequence(["Name", "sandbox.dbt_tabular3.names", ""]), + agate.MappedSequence( + [ + "Location", + 
"s3://tabular-wh-us-east-1/6efbcaf4-21ae-499d-b340-3bc1a7003f52/d2082e32-d2bd-4484-bb93-7bc445c1c6bb", + "", + ] + ), + agate.MappedSequence(["Provider", "iceberg", ""]), + ] + + config = self._get_target_http(self.project_cfg) + adapter = SparkAdapter(config) + + columns, properties = adapter._parse_describe_table(rows) + + assert columns == [("idx", "int"), ("name", "string")] + assert properties == { + "Location": "s3://tabular-wh-us-east-1/6efbcaf4-21ae-499d-b340-3bc1a7003f52/d2082e32-d2bd-4484-bb93-7bc445c1c6bb", + "Name": "sandbox.dbt_tabular3.names", + "Not partitioned": "", + "Provider": "iceberg", + "_deleted": "boolean", + "_file": "string", + "_partition": "struct<>", + "_pos": "bigint", + "_spec_id": "int", + } From a4238034a2d926b5745c2197fb18b0c2f9c96ba2 Mon Sep 17 00:00:00 2001 From: Fokko Driesprong Date: Tue, 9 May 2023 22:02:26 +0200 Subject: [PATCH 3/8] Cleanup and tests --- dbt/adapters/spark/impl.py | 78 ++++++--------------------- tests/unit/test_adapter.py | 105 ++++++++++++++++++++++++------------- 2 files changed, 84 insertions(+), 99 deletions(-) diff --git a/dbt/adapters/spark/impl.py b/dbt/adapters/spark/impl.py index 05939c04e..d82b6e657 100644 --- a/dbt/adapters/spark/impl.py +++ b/dbt/adapters/spark/impl.py @@ -24,18 +24,14 @@ from dbt.adapters.base import BaseRelation from dbt.clients.agate_helper import DEFAULT_TYPE_TESTER from dbt.events import AdapterLogger -from dbt.flags import get_flags from dbt.utils import executor, AttrDict logger = AdapterLogger("Spark") -GET_COLUMNS_IN_RELATION_RAW_MACRO_NAME = "get_columns_in_relation_raw" LIST_SCHEMAS_MACRO_NAME = "list_schemas" LIST_RELATIONS_MACRO_NAME = "list_relations_without_caching" LIST_RELATIONS_SHOW_TABLES_MACRO_NAME = "list_relations_show_tables_without_caching" DESCRIBE_TABLE_EXTENDED_MACRO_NAME = "describe_table_extended_without_caching" -DROP_RELATION_MACRO_NAME = "drop_relation" -FETCH_TBL_PROPERTIES_MACRO_NAME = "fetch_tbl_properties" KEY_TABLE_OWNER = "Owner" KEY_TABLE_STATISTICS = "Statistics" @@ -106,41 +102,29 @@ def date_function(cls) -> str: return "current_timestamp()" @classmethod - def convert_text_type(cls, agate_table, col_idx): + def convert_text_type(cls, agate_table: agate.Table, col_idx: int) -> str: return "string" @classmethod - def convert_number_type(cls, agate_table, col_idx): + def convert_number_type(cls, agate_table: agate.Table, col_idx: int) -> str: decimals = agate_table.aggregate(agate.MaxPrecision(col_idx)) return "double" if decimals else "bigint" @classmethod - def convert_date_type(cls, agate_table, col_idx): + def convert_date_type(cls, agate_table: agate.Table, col_idx: int) -> str: return "date" @classmethod - def convert_time_type(cls, agate_table, col_idx): + def convert_time_type(cls, agate_table: agate.Table, col_idx: int) -> str: return "time" @classmethod - def convert_datetime_type(cls, agate_table, col_idx): + def convert_datetime_type(cls, agate_table: agate.Table, col_idx: int) -> str: return "timestamp" def quote(self, identifier): return "`{}`".format(identifier) - def add_schema_to_cache(self, schema) -> str: - """Cache a new schema in dbt. 
It will show up in `list relations`.""" - if schema is None: - name = self.nice_connection_name() - raise dbt.exceptions.CompilationError( - "Attempted to cache a null schema for {}".format(name) - ) - if get_flags().USE_CACHE: - self.cache.add_schema(None, schema) - # so jinja doesn't render things - return "" - def _get_relation_information(self, row: agate.Row) -> RelationInfo: """relation info was fetched with SHOW TABLES EXTENDED""" try: @@ -194,15 +178,15 @@ def _parse_describe_table( # First the columns columns = [] for info_row in table_results_itr: - if info_row[0] == "": + if info_row[0] is None or info_row[0] == "" or info_row[0].startswith("#"): break - columns.append((info_row[0], info_row[1])) + columns.append((info_row[0], str(info_row[1]))) # Next all the properties table_properties = {} for info_row in table_results_itr: - info_type, info_value, _ = info_row - if not info_type.startswith("#") and info_type != "": + info_type, info_value = info_row[:2] + if info_type is not None and not info_type.startswith("#") and info_type != "": table_properties[info_type] = info_value return columns, table_properties @@ -240,21 +224,18 @@ def _build_spark_relation_list( rel_type: RelationType = ( RelationType.View - if relation.properties.get("type") == "VIEW" + if relation.properties.get("Type") == "VIEW" else RelationType.Table ) - is_delta: bool = relation.properties.get("provider") == "delta" - is_hudi: bool = relation.properties.get("provider") == "hudi" - is_iceberg: bool = relation.properties.get("provider") == "iceberg" relations.append( self.Relation.create( schema=relation.table_schema, identifier=relation.table_name, type=rel_type, - is_delta=is_delta, - is_iceberg=is_iceberg, - is_hudi=is_hudi, + is_delta=relation.properties.get("Provider") == "delta", + is_iceberg=relation.properties.get("Provider") == "iceberg", + is_hudi=relation.properties.get("Provider") == "hudi", columns=relation.columns, properties=relation.properties, ) @@ -308,9 +289,8 @@ def get_relation(self, database: str, schema: str, identifier: str) -> Optional[ return super().get_relation(database, schema, identifier) - def parse_describe_extended( - self, relation: SparkRelation, raw_rows: AttrDict - ) -> List[SparkColumn]: + def get_columns_in_relation(self, relation: BaseRelation) -> List[SparkColumn]: + assert isinstance(relation, SparkRelation), f"Expected SparkRelation, got: {relation}" # Convert the Row to a dict raw_table_stats = relation.properties.get(KEY_TABLE_STATISTICS) table_stats = SparkColumn.convert_table_stats(raw_table_stats) @@ -327,29 +307,9 @@ def parse_describe_extended( dtype=column_type, ) for idx, (column_name, column_type) in enumerate(relation.columns) + if column_name not in self.HUDI_METADATA_COLUMNS ] - def get_columns_in_relation(self, relation: BaseRelation) -> List[SparkColumn]: - columns = [] - try: - rows: AttrDict = self.execute_macro( - GET_COLUMNS_IN_RELATION_RAW_MACRO_NAME, kwargs={"relation": relation} - ) - columns = self.parse_describe_extended(relation, rows) # type: ignore - except dbt.exceptions.DbtRuntimeError as e: - # spark would throw error when table doesn't exist, where other - # CDW would just return and empty list, normalizing the behavior here - errmsg = getattr(e, "msg", "") - found_msgs = (msg in errmsg for msg in TABLE_OR_VIEW_NOT_FOUND_MESSAGES) - if any(found_msgs): - pass - else: - raise e - - # strip hudi metadata columns. 
- columns = [x for x in columns if x.name not in self.HUDI_METADATA_COLUMNS] - return columns - def parse_columns_from_information(self, relation: SparkRelation) -> List[SparkColumn]: owner = relation.properties.get(KEY_TABLE_OWNER, "") columns = [] @@ -382,12 +342,6 @@ def _get_columns_for_catalog(self, relation: BaseRelation) -> Iterable[Dict[str, as_dict["table_database"] = None yield as_dict - def get_properties(self, relation: Relation) -> Dict[str, str]: - properties = self.execute_macro( - FETCH_TBL_PROPERTIES_MACRO_NAME, kwargs={"relation": relation} - ) - return dict(properties) - def get_catalog(self, manifest): schema_map = self._get_catalog_schemas(manifest) if len(schema_map) > 1: diff --git a/tests/unit/test_adapter.py b/tests/unit/test_adapter.py index 0b66e8ae6..0e591f40d 100644 --- a/tests/unit/test_adapter.py +++ b/tests/unit/test_adapter.py @@ -7,6 +7,7 @@ from agate import Row from pyhive import hive from dbt.adapters.spark import SparkAdapter, SparkRelation +from dbt.adapters.spark.impl import RelationInfo, KEY_TABLE_OWNER from .utils import config_from_parts_or_dicts @@ -321,10 +322,15 @@ def test_parse_relation(self): input_cols = [Row(keys=["col_name", "data_type"], values=r) for r in plain_rows] config = self._get_target_http(self.project_cfg) - rows = SparkAdapter(config).parse_describe_extended(relation, input_cols) - self.assertEqual(len(rows), 4) + adapter = SparkAdapter(config) + columns, properties = adapter._parse_describe_table(input_cols) + relation_info = adapter._build_spark_relation_list( + columns, lambda a: RelationInfo(relation.schema, relation.name, columns, properties) + ) + columns = adapter.get_columns_in_relation(relation_info[0]) + self.assertEqual(len(columns), 4) self.assertEqual( - rows[0].to_column_dict(omit_none=False), + columns[0].to_column_dict(omit_none=False), { "table_database": None, "table_schema": relation.schema, @@ -341,7 +347,7 @@ def test_parse_relation(self): ) self.assertEqual( - rows[1].to_column_dict(omit_none=False), + columns[1].to_column_dict(omit_none=False), { "table_database": None, "table_schema": relation.schema, @@ -358,7 +364,7 @@ def test_parse_relation(self): ) self.assertEqual( - rows[2].to_column_dict(omit_none=False), + columns[2].to_column_dict(omit_none=False), { "table_database": None, "table_schema": relation.schema, @@ -375,7 +381,7 @@ def test_parse_relation(self): ) self.assertEqual( - rows[3].to_column_dict(omit_none=False), + columns[3].to_column_dict(omit_none=False), { "table_database": None, "table_schema": relation.schema, @@ -407,12 +413,10 @@ def test_parse_relation_with_integer_owner(self): ("Owner", 1234), ] - input_cols = [Row(keys=["col_name", "data_type"], values=r) for r in plain_rows] - config = self._get_target_http(self.project_cfg) - rows = SparkAdapter(config).parse_describe_extended(relation, input_cols) + _, properties = SparkAdapter(config)._parse_describe_table(plain_rows) - self.assertEqual(rows[0].to_column_dict().get("table_owner"), "1234") + self.assertEqual(properties.get(KEY_TABLE_OWNER), "1234") def test_parse_relation_with_statistics(self): self.maxDiff = None @@ -443,35 +447,62 @@ def test_parse_relation_with_statistics(self): ("Partition Provider", "Catalog"), ] - input_cols = [Row(keys=["col_name", "data_type"], values=r) for r in plain_rows] - config = self._get_target_http(self.project_cfg) - rows = SparkAdapter(config).parse_describe_extended(relation, input_cols) - self.assertEqual(len(rows), 1) - self.assertEqual( - rows[0].to_column_dict(omit_none=False), - { - 
"table_database": None, - "table_schema": relation.schema, - "table_name": relation.name, - "table_type": rel_type, - "table_owner": "root", - "column": "col1", - "column_index": 0, - "dtype": "decimal(22,0)", - "numeric_scale": None, - "numeric_precision": None, - "char_size": None, - "stats:bytes:description": "", - "stats:bytes:include": True, - "stats:bytes:label": "bytes", - "stats:bytes:value": 1109049927, - "stats:rows:description": "", - "stats:rows:include": True, - "stats:rows:label": "rows", - "stats:rows:value": 14093476, - }, + columns, properties = SparkAdapter(config)._parse_describe_table(plain_rows) + spark_relation = SparkRelation.create( + schema=relation.schema, + identifier=relation.name, + type=rel_type, + columns=columns, + properties=properties, ) + rows = SparkAdapter(config).parse_columns_from_information(spark_relation) + self.assertEqual(len(rows), 1) + # self.assertEqual( + # rows[0].to_column_dict(omit_none=False), + # { + # "table_database": None, + # "table_schema": relation.schema, + # "table_name": relation.name, + # "table_type": rel_type, + # "table_owner": "root", + # "column": "col1", + # "column_index": 0, + # "dtype": "decimal(22,0)", + # "numeric_scale": None, + # "numeric_precision": None, + # "char_size": None, + # "stats:bytes:description": "", + # "stats:bytes:include": True, + # "stats:bytes:label": "bytes", + # "stats:bytes:value": 1109049927, + # "stats:rows:description": "", + # "stats:rows:include": True, + # "stats:rows:label": "rows", + # "stats:rows:value": 14093476, + # }, + # ) + assert rows[0].to_column_dict(omit_none=False) == { + "table_database": None, + "table_schema": relation.schema, + "table_name": relation.name, + "table_type": rel_type, + "table_owner": "root", + "column": "col1", + "column_index": 0, + "dtype": "decimal(22,0)", + "numeric_scale": None, + "numeric_precision": None, + "char_size": None, + "stats:bytes:description": "", + "stats:bytes:include": True, + "stats:bytes:label": "bytes", + "stats:bytes:value": 1109049927, + "stats:rows:description": "", + "stats:rows:include": True, + "stats:rows:label": "rows", + "stats:rows:value": 14093476, + } def test_relation_with_database(self): config = self._get_target_http(self.project_cfg) From 036f2788c0abe98f6da53bc4f22c4d7405453b04 Mon Sep 17 00:00:00 2001 From: Fokko Driesprong Date: Tue, 9 May 2023 22:16:36 +0200 Subject: [PATCH 4/8] Cleanup --- dbt/adapters/spark/impl.py | 16 +--- tests/unit/test_adapter.py | 157 +++++++++++++++---------------------- 2 files changed, 64 insertions(+), 109 deletions(-) diff --git a/dbt/adapters/spark/impl.py b/dbt/adapters/spark/impl.py index d82b6e657..01b151686 100644 --- a/dbt/adapters/spark/impl.py +++ b/dbt/adapters/spark/impl.py @@ -128,20 +128,6 @@ def quote(self, identifier): def _get_relation_information(self, row: agate.Row) -> RelationInfo: """relation info was fetched with SHOW TABLES EXTENDED""" try: - # Example lines: - # Database: dbt_schema - # Table: names - # Owner: fokkodriesprong - # Created Time: Mon May 08 18:06:47 CEST 2023 - # Last Access: UNKNOWN - # Created By: Spark 3.3.2 - # Type: MANAGED - # Provider: hive - # Table Properties: [transient_lastDdlTime=1683562007] - # Statistics: 16 bytes - # Schema: root - # |-- idx: integer (nullable = false) - # |-- name: string (nullable = false) table_properties = {} columns = [] _schema, name, _, information_blob = row @@ -187,7 +173,7 @@ def _parse_describe_table( for info_row in table_results_itr: info_type, info_value = info_row[:2] if info_type is not None and not 
info_type.startswith("#") and info_type != "": - table_properties[info_type] = info_value + table_properties[info_type] = str(info_value) return columns, table_properties diff --git a/tests/unit/test_adapter.py b/tests/unit/test_adapter.py index 0e591f40d..3bec3df33 100644 --- a/tests/unit/test_adapter.py +++ b/tests/unit/test_adapter.py @@ -458,51 +458,30 @@ def test_parse_relation_with_statistics(self): ) rows = SparkAdapter(config).parse_columns_from_information(spark_relation) self.assertEqual(len(rows), 1) - # self.assertEqual( - # rows[0].to_column_dict(omit_none=False), - # { - # "table_database": None, - # "table_schema": relation.schema, - # "table_name": relation.name, - # "table_type": rel_type, - # "table_owner": "root", - # "column": "col1", - # "column_index": 0, - # "dtype": "decimal(22,0)", - # "numeric_scale": None, - # "numeric_precision": None, - # "char_size": None, - # "stats:bytes:description": "", - # "stats:bytes:include": True, - # "stats:bytes:label": "bytes", - # "stats:bytes:value": 1109049927, - # "stats:rows:description": "", - # "stats:rows:include": True, - # "stats:rows:label": "rows", - # "stats:rows:value": 14093476, - # }, - # ) - assert rows[0].to_column_dict(omit_none=False) == { - "table_database": None, - "table_schema": relation.schema, - "table_name": relation.name, - "table_type": rel_type, - "table_owner": "root", - "column": "col1", - "column_index": 0, - "dtype": "decimal(22,0)", - "numeric_scale": None, - "numeric_precision": None, - "char_size": None, - "stats:bytes:description": "", - "stats:bytes:include": True, - "stats:bytes:label": "bytes", - "stats:bytes:value": 1109049927, - "stats:rows:description": "", - "stats:rows:include": True, - "stats:rows:label": "rows", - "stats:rows:value": 14093476, - } + self.assertEqual( + rows[0].to_column_dict(omit_none=False), + { + "table_database": None, + "table_schema": relation.schema, + "table_name": relation.name, + "table_type": rel_type, + "table_owner": "root", + "column": "col1", + "column_index": 0, + "dtype": "decimal(22,0)", + "numeric_scale": None, + "numeric_precision": None, + "char_size": None, + "stats:bytes:description": "", + "stats:bytes:include": True, + "stats:bytes:label": "bytes", + "stats:bytes:value": 1109049927, + "stats:rows:description": "", + "stats:rows:include": True, + "stats:rows:label": "rows", + "stats:rows:value": 14093476, + }, + ) def test_relation_with_database(self): config = self._get_target_http(self.project_cfg) @@ -592,7 +571,19 @@ def test_parse_columns_from_information_with_table_type_and_delta_provider(self) ) self.assertEqual(len(tables), 1) - columns = adapter.parse_describe_extended(tables[0], None) + table = tables[0] + + assert isinstance(table, SparkRelation) + + columns = adapter.get_columns_in_relation( + SparkRelation.create( + type=rel_type, + schema="default_schema", + identifier="mytable", + columns=table.columns, + properties=table.properties, + ) + ) self.assertEqual(len(columns), 5) self.assertEqual( @@ -684,7 +675,19 @@ def test_parse_columns_from_information_with_view_type(self): ) self.assertEqual(len(tables), 1) - columns = adapter.parse_describe_extended(tables[0], None) + table = tables[0] + + assert isinstance(table, SparkRelation) + + columns = adapter.get_columns_in_relation( + SparkRelation.create( + type=rel_type, + schema="default_schema", + identifier="myview", + columns=table.columns, + properties=table.properties, + ) + ) self.assertEqual(len(columns), 5) self.assertEqual( @@ -757,7 +760,19 @@ def 
test_parse_columns_from_information_with_table_type_and_parquet_provider(sel ) self.assertEqual(len(tables), 1) - columns = adapter.parse_describe_extended(tables[0], None) + table = tables[0] + + assert isinstance(table, SparkRelation) + + columns = adapter.get_columns_in_relation( + SparkRelation.create( + type=rel_type, + schema="default_schema", + identifier="mytable", + columns=table.columns, + properties=table.properties, + ) + ) self.assertEqual(len(columns), 5) self.assertEqual( @@ -809,49 +824,3 @@ def test_parse_columns_from_information_with_table_type_and_parquet_provider(sel "stats:rows:value": 12345678, }, ) - - def test_parse_columns_from_describe_extended(self): - self.maxDiff = None - rows = [ - agate.MappedSequence(["idx", "int", ""]), - agate.MappedSequence(["name", "string", ""]), - agate.MappedSequence(["", "", ""]), - agate.MappedSequence(["# Partitioning", "", ""]), - agate.MappedSequence(["Not partitioned", "", ""]), - agate.MappedSequence(["", "", ""]), - agate.MappedSequence(["# Metadata Columns", "", ""]), - agate.MappedSequence(["_spec_id", "int", ""]), - agate.MappedSequence(["_partition", "struct<>", ""]), - agate.MappedSequence(["_file", "string", ""]), - agate.MappedSequence(["_pos", "bigint", ""]), - agate.MappedSequence(["_deleted", "boolean", ""]), - agate.MappedSequence(["", "", ""]), - agate.MappedSequence(["# Detailed Table Information", "", ""]), - agate.MappedSequence(["Name", "sandbox.dbt_tabular3.names", ""]), - agate.MappedSequence( - [ - "Location", - "s3://tabular-wh-us-east-1/6efbcaf4-21ae-499d-b340-3bc1a7003f52/d2082e32-d2bd-4484-bb93-7bc445c1c6bb", - "", - ] - ), - agate.MappedSequence(["Provider", "iceberg", ""]), - ] - - config = self._get_target_http(self.project_cfg) - adapter = SparkAdapter(config) - - columns, properties = adapter._parse_describe_table(rows) - - assert columns == [("idx", "int"), ("name", "string")] - assert properties == { - "Location": "s3://tabular-wh-us-east-1/6efbcaf4-21ae-499d-b340-3bc1a7003f52/d2082e32-d2bd-4484-bb93-7bc445c1c6bb", - "Name": "sandbox.dbt_tabular3.names", - "Not partitioned": "", - "Provider": "iceberg", - "_deleted": "boolean", - "_file": "string", - "_partition": "struct<>", - "_pos": "bigint", - "_spec_id": "int", - } From f697b80eb885a8cae7996d7899d25abd0cd14988 Mon Sep 17 00:00:00 2001 From: Fokko Driesprong Date: Tue, 9 May 2023 22:41:48 +0200 Subject: [PATCH 5/8] Improve error reporting --- dbt/adapters/spark/impl.py | 39 ++++++++++++++++++++++++++++++-------- tests/unit/test_adapter.py | 6 +++--- 2 files changed, 34 insertions(+), 11 deletions(-) diff --git a/dbt/adapters/spark/impl.py b/dbt/adapters/spark/impl.py index 01b151686..26b9ed1dd 100644 --- a/dbt/adapters/spark/impl.py +++ b/dbt/adapters/spark/impl.py @@ -28,6 +28,7 @@ logger = AdapterLogger("Spark") +GET_COLUMNS_IN_RELATION_RAW_MACRO_NAME = "get_columns_in_relation_raw" LIST_SCHEMAS_MACRO_NAME = "list_schemas" LIST_RELATIONS_MACRO_NAME = "list_relations_without_caching" LIST_RELATIONS_SHOW_TABLES_MACRO_NAME = "list_relations_show_tables_without_caching" @@ -155,7 +156,7 @@ def _get_relation_information(self, row: agate.Row) -> RelationInfo: return RelationInfo(_schema, name, columns, table_properties) - def _parse_describe_table( + def _parse_describe_table_extended( self, table_results: agate.Table ) -> Tuple[List[Tuple[str, str]], Dict[str, str]]: # Wrap it in an iter, so we continue reading the properties from where we stopped reading columns @@ -195,7 +196,7 @@ def _get_relation_information_using_describe(self, row: 
agate.Row) -> RelationIn logger.debug(f"Error while retrieving information about {table_name}: {e.msg}") table_results = AttrDict() - columns, table_properties = self._parse_describe_table(table_results) + columns, table_properties = self._parse_describe_table_extended(table_results) return RelationInfo(_schema, name, columns, table_properties) def _build_spark_relation_list( @@ -276,9 +277,29 @@ def get_relation(self, database: str, schema: str, identifier: str) -> Optional[ return super().get_relation(database, schema, identifier) def get_columns_in_relation(self, relation: BaseRelation) -> List[SparkColumn]: - assert isinstance(relation, SparkRelation), f"Expected SparkRelation, got: {relation}" + assert isinstance(relation, SparkRelation) + if relation.columns is not None and len(relation.columns) > 0: + columns = relation.columns + properties = relation.properties + else: + try: + describe_extended_result = self.execute_macro( + GET_COLUMNS_IN_RELATION_RAW_MACRO_NAME, kwargs={"relation": relation} + ) + columns, properties = self._parse_describe_table_extended(describe_extended_result) + except dbt.exceptions.DbtRuntimeError as e: + # spark would throw error when table doesn't exist, where other + # CDW would just return and empty list, normalizing the behavior here + errmsg = getattr(e, "msg", "") + found_msgs = (msg in errmsg for msg in TABLE_OR_VIEW_NOT_FOUND_MESSAGES) + if any(found_msgs): + columns = [] + properties = {} + else: + raise e + # Convert the Row to a dict - raw_table_stats = relation.properties.get(KEY_TABLE_STATISTICS) + raw_table_stats = properties.get(KEY_TABLE_STATISTICS) table_stats = SparkColumn.convert_table_stats(raw_table_stats) return [ SparkColumn( @@ -286,13 +307,13 @@ def get_columns_in_relation(self, relation: BaseRelation) -> List[SparkColumn]: table_schema=relation.schema, table_name=relation.name, table_type=relation.type, - table_owner=relation.properties.get(KEY_TABLE_OWNER, ""), + table_owner=properties.get(KEY_TABLE_OWNER, ""), table_stats=table_stats, column=column_name, column_index=idx, dtype=column_type, ) - for idx, (column_name, column_type) in enumerate(relation.columns) + for idx, (column_name, column_type) in enumerate(columns) if column_name not in self.HUDI_METADATA_COLUMNS ] @@ -385,19 +406,21 @@ def get_rows_different_sql( column_names: Optional[List[str]] = None, except_operator: str = "EXCEPT", ) -> str: - """Generate SQL for a query that returns a single row with a two + """Generate SQL for a query that returns a single row with two columns: the number of rows that are different between the two relations and the number of mismatched rows. """ # This method only really exists for test reasons. 
names: List[str] - if column_names is None: + if not column_names: columns = self.get_columns_in_relation(relation_a) names = sorted((self.quote(c.name) for c in columns)) else: names = sorted((self.quote(n) for n in column_names)) columns_csv = ", ".join(names) + assert columns_csv, f"Could not determine columns for: {relation_a}" + sql = COLUMNS_EQUAL_SQL.format( columns=columns_csv, relation_a=str(relation_a), diff --git a/tests/unit/test_adapter.py b/tests/unit/test_adapter.py index 3bec3df33..aa972913e 100644 --- a/tests/unit/test_adapter.py +++ b/tests/unit/test_adapter.py @@ -323,7 +323,7 @@ def test_parse_relation(self): config = self._get_target_http(self.project_cfg) adapter = SparkAdapter(config) - columns, properties = adapter._parse_describe_table(input_cols) + columns, properties = adapter._parse_describe_table_extended(input_cols) relation_info = adapter._build_spark_relation_list( columns, lambda a: RelationInfo(relation.schema, relation.name, columns, properties) ) @@ -414,7 +414,7 @@ def test_parse_relation_with_integer_owner(self): ] config = self._get_target_http(self.project_cfg) - _, properties = SparkAdapter(config)._parse_describe_table(plain_rows) + _, properties = SparkAdapter(config)._parse_describe_table_extended(plain_rows) self.assertEqual(properties.get(KEY_TABLE_OWNER), "1234") @@ -448,7 +448,7 @@ def test_parse_relation_with_statistics(self): ] config = self._get_target_http(self.project_cfg) - columns, properties = SparkAdapter(config)._parse_describe_table(plain_rows) + columns, properties = SparkAdapter(config)._parse_describe_table_extended(plain_rows) spark_relation = SparkRelation.create( schema=relation.schema, identifier=relation.name, From f97c1617089285f7124b46d4f5e1e096e54a564a Mon Sep 17 00:00:00 2001 From: Fokko Driesprong Date: Wed, 10 May 2023 15:39:38 +0200 Subject: [PATCH 6/8] Cleanup --- dbt/adapters/spark/impl.py | 31 ++++++++++++++++++++++++++----- 1 file changed, 26 insertions(+), 5 deletions(-) diff --git a/dbt/adapters/spark/impl.py b/dbt/adapters/spark/impl.py index 26b9ed1dd..e1618dbb6 100644 --- a/dbt/adapters/spark/impl.py +++ b/dbt/adapters/spark/impl.py @@ -24,6 +24,7 @@ from dbt.adapters.base import BaseRelation from dbt.clients.agate_helper import DEFAULT_TYPE_TESTER from dbt.events import AdapterLogger +from dbt.flags import get_flags from dbt.utils import executor, AttrDict logger = AdapterLogger("Spark") @@ -33,6 +34,8 @@ LIST_RELATIONS_MACRO_NAME = "list_relations_without_caching" LIST_RELATIONS_SHOW_TABLES_MACRO_NAME = "list_relations_show_tables_without_caching" DESCRIBE_TABLE_EXTENDED_MACRO_NAME = "describe_table_extended_without_caching" +DROP_RELATION_MACRO_NAME = "drop_relation" +FETCH_TBL_PROPERTIES_MACRO_NAME = "fetch_tbl_properties" KEY_TABLE_OWNER = "Owner" KEY_TABLE_STATISTICS = "Statistics" @@ -103,29 +106,41 @@ def date_function(cls) -> str: return "current_timestamp()" @classmethod - def convert_text_type(cls, agate_table: agate.Table, col_idx: int) -> str: + def convert_text_type(cls, agate_table, col_idx): return "string" @classmethod - def convert_number_type(cls, agate_table: agate.Table, col_idx: int) -> str: + def convert_number_type(cls, agate_table, col_idx): decimals = agate_table.aggregate(agate.MaxPrecision(col_idx)) return "double" if decimals else "bigint" @classmethod - def convert_date_type(cls, agate_table: agate.Table, col_idx: int) -> str: + def convert_date_type(cls, agate_table, col_idx): return "date" @classmethod - def convert_time_type(cls, agate_table: agate.Table, col_idx: int) -> 
str: + def convert_time_type(cls, agate_table, col_idx): return "time" @classmethod - def convert_datetime_type(cls, agate_table: agate.Table, col_idx: int) -> str: + def convert_datetime_type(cls, agate_table, col_idx): return "timestamp" def quote(self, identifier): return "`{}`".format(identifier) + def add_schema_to_cache(self, schema) -> str: + """Cache a new schema in dbt. It will show up in `list relations`.""" + if schema is None: + name = self.nice_connection_name() + raise dbt.exceptions.CompilationError( + "Attempted to cache a null schema for {}".format(name) + ) + if get_flags().USE_CACHE: + self.cache.add_schema(None, schema) + # so jinja doesn't render things + return "" + def _get_relation_information(self, row: agate.Row) -> RelationInfo: """relation info was fetched with SHOW TABLES EXTENDED""" try: @@ -349,6 +364,12 @@ def _get_columns_for_catalog(self, relation: BaseRelation) -> Iterable[Dict[str, as_dict["table_database"] = None yield as_dict + def get_properties(self, relation: Relation) -> Dict[str, str]: + properties = self.execute_macro( + FETCH_TBL_PROPERTIES_MACRO_NAME, kwargs={"relation": relation} + ) + return dict(properties) + def get_catalog(self, manifest): schema_map = self._get_catalog_schemas(manifest) if len(schema_map) > 1: From 85395df96969d43ce11c856af01f8c4aac2b1f6a Mon Sep 17 00:00:00 2001 From: Fokko Driesprong Date: Thu, 29 Jun 2023 18:36:37 +0200 Subject: [PATCH 7/8] Update dbt/adapters/spark/impl.py --- dbt/adapters/spark/impl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dbt/adapters/spark/impl.py b/dbt/adapters/spark/impl.py index 241e1b965..4e489366b 100644 --- a/dbt/adapters/spark/impl.py +++ b/dbt/adapters/spark/impl.py @@ -371,7 +371,7 @@ def get_catalog(self, manifest: Manifest) -> Tuple[agate.Table, List[Exception]] ) with executor(self.config) as tpe: - futures: List[Future[agate.Table]] = [] # type: ignore + futures: List[Future[agate.Table]] = [] for info, schemas in schema_map.items(): for schema in schemas: futures.append( From 7a9c50bd88360e19848dc2cf1b194380db86965a Mon Sep 17 00:00:00 2001 From: Fokko Driesprong Date: Mon, 7 Aug 2023 14:16:38 +0200 Subject: [PATCH 8/8] Update dbt/adapters/spark/impl.py Co-authored-by: Cor --- dbt/adapters/spark/impl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dbt/adapters/spark/impl.py b/dbt/adapters/spark/impl.py index 4e489366b..6a571ec4b 100644 --- a/dbt/adapters/spark/impl.py +++ b/dbt/adapters/spark/impl.py @@ -433,7 +433,7 @@ def get_rows_different_sql( names = sorted((self.quote(n) for n in column_names)) columns_csv = ", ".join(names) - assert columns_csv, f"Could not determine columns for: {relation_a}" + assert columns_csv, f"Could not find columns for: {relation_a}" sql = COLUMNS_EQUAL_SQL.format( columns=columns_csv,
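
The parsing approach these patches converge on can be illustrated outside the adapter. The sketch below is a simplified, self-contained rendition of the two paths, assuming plain Python tuples in place of agate rows: splitting `DESCRIBE TABLE EXTENDED` output into a column list and a property dict (what `_parse_describe_table_extended` does), and parsing the `information` blob returned by `SHOW TABLES EXTENDED` with the column regex. The helper names used here are illustrative and do not exist in the adapter itself.

import re
from typing import Dict, List, Tuple

# Mirrors INFORMATION_COLUMN_REGEX from the final patch; matches lines such as
# " |-- idx: integer (nullable = false)", including nested struct fields.
COLUMN_REGEX = re.compile(r"[ | ]* \|-- (.*)\: (.*) \(nullable = (.*)\)")


def parse_describe_rows(
    rows: List[Tuple[str, str, str]]
) -> Tuple[List[Tuple[str, str]], Dict[str, str]]:
    """Split DESCRIBE TABLE EXTENDED rows into (columns, properties).

    Rows before the first blank or '#'-prefixed row are column definitions;
    the remaining non-comment rows are treated as table properties.
    """
    rows_iter = iter(rows)  # keep reading properties from where the columns stopped

    columns: List[Tuple[str, str]] = []
    for name, dtype, _comment in rows_iter:
        if not name or name.startswith("#"):
            break
        columns.append((name, str(dtype)))

    properties: Dict[str, str] = {}
    for name, value, _comment in rows_iter:
        if name and not name.startswith("#"):
            properties[name] = str(value)

    return columns, properties


def parse_information_blob(
    information: str,
) -> Tuple[List[Tuple[str, str]], Dict[str, str]]:
    """Parse the information blob returned by SHOW TABLES EXTENDED."""
    columns: List[Tuple[str, str]] = []
    properties: Dict[str, str] = {}
    for line in information.split("\n"):
        if not line:
            continue
        if " |--" in line:
            # A column line from the schema dump
            match = COLUMN_REGEX.match(line)
            if match:
                columns.append((match[1], match[2]))
        else:
            # A "Key: value" property line
            parts = line.split(": ", maxsplit=2)
            if len(parts) == 2:
                properties[parts[0]] = parts[1]
    return columns, properties


if __name__ == "__main__":
    blob = (
        "Database: dbt_schema\n"
        "Table: names\n"
        "Owner: fokkodriesprong\n"
        "Type: MANAGED\n"
        "Provider: hive\n"
        "Schema: root\n"
        " |-- idx: integer (nullable = false)\n"
        " |-- name: string (nullable = false)\n"
    )
    cols, props = parse_information_blob(blob)
    assert cols == [("idx", "integer"), ("name", "string")]
    assert props["Provider"] == "hive" and props["Type"] == "MANAGED"

    describe = [
        ("idx", "int", ""),
        ("name", "string", ""),
        ("", "", ""),
        ("# Detailed Table Information", "", ""),
        ("Provider", "iceberg", ""),
        ("Owner", "fokko", ""),
    ]
    cols, props = parse_describe_rows(describe)
    assert cols == [("idx", "int"), ("name", "string")]
    assert props == {"Provider": "iceberg", "Owner": "fokko"}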