Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit 0680fac

Browse files
committed
Fix incorrect parsing of Decimal type and add UUID support
1 parent f577dc0 commit 0680fac

File tree

1 file changed

+12
-7
lines changed

1 file changed

+12
-7
lines changed

data_diff/databases/databricks.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import math
33

44
from .database_types import *
5-
from .base import TIMESTAMP_PRECISION_POS, Database, import_helper, _query_conn
5+
from .base import Database, import_helper, _query_conn, parse_table_name
66

77

88
@import_helper("databricks")
@@ -84,26 +84,27 @@ def query_table_schema(self, path: DbPath, filter_columns: Optional[Sequence[str
8484

8585
resulted_rows = []
8686
for row in rows:
87-
type_cls = self.TYPE_CLASSES.get(str(row.TYPE_NAME), UnknownColType)
87+
row_type = 'DECIMAL' if row.DATA_TYPE == 3 else row.TYPE_NAME
88+
type_cls = self.TYPE_CLASSES.get(row_type, UnknownColType)
8889

8990
if issubclass(type_cls, Integer):
90-
row = (row.COLUMN_NAME, row.TYPE_NAME, None, None, 0)
91+
row = (row.COLUMN_NAME, row_type, None, None, 0)
9192

9293
elif issubclass(type_cls, Float):
9394
numeric_precision = math.ceil(row.DECIMAL_DIGITS / math.log(2, 10))
94-
row = (row.COLUMN_NAME, row.TYPE_NAME, None, numeric_precision, None)
95+
row = (row.COLUMN_NAME, row_type, None, numeric_precision, None)
9596

9697
elif issubclass(type_cls, Decimal):
9798
# TYPE_NAME has a format DECIMAL(x,y)
9899
items = row.TYPE_NAME[8:].rstrip(')').split(',')
99100
numeric_precision, numeric_scale = int(items[0]), int(items[1])
100-
row = (row.COLUMN_NAME, row.TYPE_NAME, None, numeric_precision, numeric_scale)
101+
row = (row.COLUMN_NAME, row_type, None, numeric_precision, numeric_scale)
101102

102103
elif issubclass(type_cls, Timestamp):
103-
row = (row.COLUMN_NAME, row.TYPE_NAME, row.DECIMAL_DIGITS, None, None)
104+
row = (row.COLUMN_NAME, row_type, row.DECIMAL_DIGITS, None, None)
104105

105106
else:
106-
row = (row.COLUMN_NAME, row.TYPE_NAME, None, None, None)
107+
row = (row.COLUMN_NAME, row_type, None, None, None)
107108

108109
resulted_rows.append(row)
109110
return {row[0]: self._parse_type(path, *row) for row in resulted_rows}
@@ -121,5 +122,9 @@ def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
121122
def normalize_number(self, value: str, coltype: NumericType) -> str:
122123
return self.to_string(f"cast({value} as decimal(38, {coltype.precision}))")
123124

125+
def parse_table_name(self, name: str) -> DbPath:
126+
path = parse_table_name(name)
127+
return self._normalize_table_path(path)
128+
124129
def close(self):
125130
self._conn.close()

0 commit comments

Comments
 (0)