1313 ColType ,
1414 UnknownColType ,
1515)
16- from .base import MD5_HEXDIGITS , CHECKSUM_HEXDIGITS , BaseDialect , Database , import_helper , parse_table_name
16+ from .base import MD5_HEXDIGITS , CHECKSUM_HEXDIGITS , BaseDialect , ThreadedDatabase , import_helper , parse_table_name
1717
1818
1919@import_helper (text = "You can install it using 'pip install databricks-sql-connector'" )
@@ -68,43 +68,45 @@ def _convert_db_precision_to_digits(self, p: int) -> int:
6868 return max (super ()._convert_db_precision_to_digits (p ) - 1 , 0 )
6969
7070
71- class Databricks (Database ):
71+ class Databricks (ThreadedDatabase ):
7272 dialect = Dialect ()
7373
74- def __init__ (
75- self ,
76- http_path : str ,
77- access_token : str ,
78- server_hostname : str ,
79- catalog : str = "hive_metastore" ,
80- schema : str = "default" ,
81- ** kwargs ,
82- ):
83- databricks = import_databricks ()
84-
85- self ._conn = databricks .sql .connect (
86- server_hostname = server_hostname , http_path = http_path , access_token = access_token , catalog = catalog
87- )
88-
def __init__(self, *, thread_count, **kw):
    """Store the connection settings and initialise the threaded base class.

    Args:
        thread_count: Forwarded unchanged to the base-class initializer.
        **kw: Connection settings. ``create_connection`` reads
            ``server_hostname``, ``http_path``, ``access_token`` and
            ``catalog`` from them; ``schema`` selects the default schema.
    """
    logging.getLogger("databricks.sql").setLevel(logging.WARNING)

    # The explicit-keyword signature this replaced defaulted ``catalog`` to
    # "hive_metastore" and ``schema`` to "default". Keep both defaults so
    # callers that omit them neither hit a KeyError on ``catalog`` nor end
    # up with the catalog name used as the default schema.
    self._args = {"catalog": "hive_metastore", **kw}
    self.default_schema = kw.get("schema", "default")
    super().__init__(thread_count=thread_count)
9480
95- def _query (self , sql_code : str ) -> list :
96- "Uses the standard SQL cursor interface"
97- return self ._query_conn (self ._conn , sql_code )
def create_connection(self):
    """Open a new Databricks SQL connection from the stored settings.

    Returns:
        The connection object returned by ``databricks.sql.connect``.

    Raises:
        ConnectionError: If the connector raises ``databricks.sql.exc.Error``;
            the original exception is chained as the cause.
        KeyError: If ``server_hostname``, ``http_path`` or ``access_token``
            was not supplied to the constructor.
    """
    databricks = import_databricks()
    args = self._args

    try:
        return databricks.sql.connect(
            server_hostname=args["server_hostname"],
            http_path=args["http_path"],
            access_token=args["access_token"],
            # ``catalog`` is optional for callers; fall back to the same
            # default the previous explicit-keyword constructor used.
            catalog=args.get("catalog", "hive_metastore"),
        )
    except databricks.sql.exc.Error as e:
        raise ConnectionError(*e.args) from e
9893
9994 def query_table_schema (self , path : DbPath ) -> Dict [str , tuple ]:
10095 # Databricks has INFORMATION_SCHEMA only for Databricks Runtime, not for Databricks SQL.
10196 # https://docs.databricks.com/spark/latest/spark-sql/language-manual/information-schema/columns.html
10297 # So, to obtain information about schema, we should use another approach.
10398
99+ conn = self .create_connection ()
100+
104101 schema , table = self ._normalize_table_path (path )
105- with self ._conn .cursor () as cursor :
106- cursor .columns (catalog_name = self .catalog , schema_name = schema , table_name = table )
107- rows = cursor .fetchall ()
102+ with conn .cursor () as cursor :
103+ cursor .columns (catalog_name = self ._args ['catalog' ], schema_name = schema , table_name = table )
104+ try :
105+ rows = cursor .fetchall ()
106+ except :
107+ rows = None
108+ finally :
109+ conn .close ()
108110 if not rows :
109111 raise RuntimeError (f"{ self .name } : Table '{ '.' .join (path )} ' does not exist, or has no columns" )
110112
@@ -121,7 +123,7 @@ def _process_table_schema(
121123 resulted_rows = []
122124 for row in rows :
123125 row_type = "DECIMAL" if row [1 ].startswith ("DECIMAL" ) else row [1 ]
124- type_cls = self .TYPE_CLASSES .get (row_type , UnknownColType )
126+ type_cls = self .dialect . TYPE_CLASSES .get (row_type , UnknownColType )
125127
126128 if issubclass (type_cls , Integer ):
127129 row = (row [0 ], row_type , None , None , 0 )
@@ -152,9 +154,6 @@ def parse_table_name(self, name: str) -> DbPath:
152154 path = parse_table_name (name )
153155 return self ._normalize_table_path (path )
154156
155- def close (self ):
156- self ._conn .close ()
157-
@property
def is_autocommit(self) -> bool:
    """Report whether statements are committed automatically; always ``True`` here."""
    return True
0 commit comments