22import math
33
44from .database_types import *
5- from .base import TIMESTAMP_PRECISION_POS , Database , import_helper , _query_conn
5+ from .base import Database , import_helper , _query_conn , parse_table_name
66
77
88@import_helper ("databricks" )
@@ -84,26 +84,27 @@ def query_table_schema(self, path: DbPath, filter_columns: Optional[Sequence[str
8484
8585 resulted_rows = []
8686 for row in rows :
87- type_cls = self .TYPE_CLASSES .get (str (row .TYPE_NAME ), UnknownColType )
87+ row_type = 'DECIMAL' if row .DATA_TYPE == 3 else row .TYPE_NAME
88+ type_cls = self .TYPE_CLASSES .get (row_type , UnknownColType )
8889
8990 if issubclass (type_cls , Integer ):
90- row = (row .COLUMN_NAME , row . TYPE_NAME , None , None , 0 )
91+ row = (row .COLUMN_NAME , row_type , None , None , 0 )
9192
9293 elif issubclass (type_cls , Float ):
9394 numeric_precision = math .ceil (row .DECIMAL_DIGITS / math .log (2 , 10 ))
94- row = (row .COLUMN_NAME , row . TYPE_NAME , None , numeric_precision , None )
95+ row = (row .COLUMN_NAME , row_type , None , numeric_precision , None )
9596
9697 elif issubclass (type_cls , Decimal ):
9798 # TYPE_NAME has a format DECIMAL(x,y)
9899 items = row .TYPE_NAME [8 :].rstrip (')' ).split (',' )
99100 numeric_precision , numeric_scale = int (items [0 ]), int (items [1 ])
100- row = (row .COLUMN_NAME , row . TYPE_NAME , None , numeric_precision , numeric_scale )
101+ row = (row .COLUMN_NAME , row_type , None , numeric_precision , numeric_scale )
101102
102103 elif issubclass (type_cls , Timestamp ):
103- row = (row .COLUMN_NAME , row . TYPE_NAME , row .DECIMAL_DIGITS , None , None )
104+ row = (row .COLUMN_NAME , row_type , row .DECIMAL_DIGITS , None , None )
104105
105106 else :
106- row = (row .COLUMN_NAME , row . TYPE_NAME , None , None , None )
107+ row = (row .COLUMN_NAME , row_type , None , None , None )
107108
108109 resulted_rows .append (row )
109110 return {row [0 ]: self ._parse_type (path , * row ) for row in resulted_rows }
@@ -121,5 +122,9 @@ def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
121122 def normalize_number (self , value : str , coltype : NumericType ) -> str :
122123 return self .to_string (f"cast({ value } as decimal(38, { coltype .precision } ))" )
123124
125+ def parse_table_name (self , name : str ) -> DbPath :
126+ path = parse_table_name (name )
127+ return self ._normalize_table_path (path )
128+
124129 def close (self ):
125130 self ._conn .close ()
0 commit comments