33from itertools import zip_longest
44import re
55from abc import ABC , abstractmethod
6- from runtype import dataclass
76import logging
8- from typing import Sequence , Tuple , Optional , List
7+ from typing import Sequence , Tuple , Optional , List , Type
98from concurrent .futures import ThreadPoolExecutor
109import threading
1110from typing import Dict
12-
1311import dsnparse
1412import sys
1513
14+ from runtype import dataclass
15+
1616from .sql import DbPath , SqlOrStr , Compiler , Explain , Select
17+ from .database_types import *
1718
1819
1920logger = logging .getLogger ("database" )
@@ -109,149 +110,6 @@ def _query_conn(conn, sql_code: str) -> list:
109110 return c .fetchall ()
110111
111112
112- class ColType :
113- pass
114-
115-
116- @dataclass
117- class PrecisionType (ColType ):
118- precision : Optional [int ]
119- rounds : bool
120-
121-
122- class TemporalType (PrecisionType ):
123- pass
124-
125-
126- class Timestamp (TemporalType ):
127- pass
128-
129-
130- class TimestampTZ (TemporalType ):
131- pass
132-
133-
134- class Datetime (TemporalType ):
135- pass
136-
137-
138- @dataclass
139- class NumericType (ColType ):
140- # 'precision' signifies how many fractional digits (after the dot) we want to compare
141- precision : int
142-
143-
144- class Float (NumericType ):
145- pass
146-
147-
148- class Decimal (NumericType ):
149- pass
150-
151-
152- @dataclass
153- class Integer (Decimal ):
154- def __post_init__ (self ):
155- assert self .precision == 0
156-
157-
158- @dataclass
159- class UnknownColType (ColType ):
160- text : str
161-
162-
163- class AbstractDatabase (ABC ):
164- @abstractmethod
165- def quote (self , s : str ):
166- "Quote SQL name (implementation specific)"
167- ...
168-
169- @abstractmethod
170- def to_string (self , s : str ) -> str :
171- "Provide SQL for casting a column to string"
172- ...
173-
174- @abstractmethod
175- def md5_to_int (self , s : str ) -> str :
176- "Provide SQL for computing md5 and returning an int"
177- ...
178-
179- @abstractmethod
180- def _query (self , sql_code : str ) -> list :
181- "Send query to database and return result"
182- ...
183-
184- @abstractmethod
185- def select_table_schema (self , path : DbPath ) -> str :
186- "Provide SQL for selecting the table schema as (name, type, date_prec, num_prec)"
187- ...
188-
189- @abstractmethod
190- def query_table_schema (self , path : DbPath , filter_columns : Optional [Sequence [str ]] = None ) -> Dict [str , ColType ]:
191- "Query the table for its schema for table in 'path', and return {column: type}"
192- ...
193-
194- @abstractmethod
195- def parse_table_name (self , name : str ) -> DbPath :
196- "Parse the given table name into a DbPath"
197- ...
198-
199- @abstractmethod
200- def close (self ):
201- "Close connection(s) to the database instance. Querying will stop functioning."
202- ...
203-
204- @abstractmethod
205- def normalize_timestamp (self , value : str , coltype : ColType ) -> str :
206- """Creates an SQL expression, that converts 'value' to a normalized timestamp.
207-
208- The returned expression must accept any SQL datetime/timestamp, and return a string.
209-
210- Date format: "YYYY-MM-DD HH:mm:SS.FFFFFF"
211-
212- Precision of dates should be rounded up/down according to coltype.rounds
213- """
214- ...
215-
216- @abstractmethod
217- def normalize_number (self , value : str , coltype : ColType ) -> str :
218- """Creates an SQL expression, that converts 'value' to a normalized number.
219-
220- The returned expression must accept any SQL int/numeric/float, and return a string.
221-
222- - Floats/Decimals are expected in the format
223- "I.P"
224-
225- Where I is the integer part of the number (as many digits as necessary),
226- and must be at least one digit (0).
227- P is the fractional digits, the amount of which is specified with
228- coltype.precision. Trailing zeroes may be necessary.
229- If P is 0, the dot is omitted.
230-
231- Note: This precision is different than the one used by databases. For decimals,
232- it's the same as ``numeric_scale``, and for floats, who use binary precision,
233- it can be calculated as ``log10(2**numeric_precision)``.
234- """
235- ...
236-
237- def normalize_value_by_type (self , value : str , coltype : ColType ) -> str :
238- """Creates an SQL expression, that converts 'value' to a normalized representation.
239-
240- The returned expression must accept any SQL value, and return a string.
241-
242- The default implementation dispatches to a method according to ``coltype``:
243-
244- TemporalType -> normalize_timestamp()
245- NumericType -> normalize_number()
246- -else- -> to_string()
247-
248- """
249- if isinstance (coltype , TemporalType ):
250- return self .normalize_timestamp (value , coltype )
251- elif isinstance (coltype , NumericType ):
252- return self .normalize_number (value , coltype )
253- return self .to_string (f"{ value } " )
254-
255113
256114class Database (AbstractDatabase ):
257115 """Base abstract class for databases.
@@ -261,8 +119,8 @@ class Database(AbstractDatabase):
261119 Instanciated using :meth:`~data_diff.connect_to_uri`
262120 """
263121
264- DATETIME_TYPES = {}
265- default_schema = None
122+ DATETIME_TYPES : Dict [ str , type ] = {}
123+ default_schema : str = None
266124
267125 @property
268126 def name (self ):
@@ -412,9 +270,6 @@ def _query_in_worker(self, sql_code: str):
412270 raise self ._init_error
413271 return _query_conn (self .thread_local .conn , sql_code )
414272
415- def close (self ):
416- self ._queue .shutdown (True )
417-
418273 @abstractmethod
419274 def create_connection (self ):
420275 ...
@@ -481,7 +336,7 @@ def md5_to_int(self, s: str) -> str:
481336 def to_string (self , s : str ):
482337 return f"{ s } ::varchar"
483338
484- def normalize_timestamp (self , value : str , coltype : ColType ) -> str :
339+ def normalize_timestamp (self , value : str , coltype : TemporalType ) -> str :
485340 if coltype .rounds :
486341 return f"to_char({ value } ::timestamp({ coltype .precision } ), 'YYYY-mm-dd HH24:MI:SS.US')"
487342
@@ -490,7 +345,7 @@ def normalize_timestamp(self, value: str, coltype: ColType) -> str:
490345 f"RPAD(LEFT({ timestamp6 } , { TIMESTAMP_PRECISION_POS + coltype .precision } ), { TIMESTAMP_PRECISION_POS + 6 } , '0')"
491346 )
492347
493- def normalize_number (self , value : str , coltype : ColType ) -> str :
348+ def normalize_number (self , value : str , coltype : NumericType ) -> str :
494349 return self .to_string (f"{ value } ::decimal(38, { coltype .precision } )" )
495350
496351
@@ -531,7 +386,7 @@ def _query(self, sql_code: str) -> list:
531386 def close (self ):
532387 self ._conn .close ()
533388
534- def normalize_timestamp (self , value : str , coltype : ColType ) -> str :
389+ def normalize_timestamp (self , value : str , coltype : TemporalType ) -> str :
535390 # TODO
536391 if coltype .rounds :
537392 s = f"date_format(cast({ value } as timestamp(6)), '%Y-%m-%d %H:%i:%S.%f')"
@@ -540,7 +395,7 @@ def normalize_timestamp(self, value: str, coltype: ColType) -> str:
540395
541396 return f"RPAD(RPAD({ s } , { TIMESTAMP_PRECISION_POS + coltype .precision } , '.'), { TIMESTAMP_PRECISION_POS + 6 } , '0')"
542397
543- def normalize_number (self , value : str , coltype : ColType ) -> str :
398+ def normalize_number (self , value : str , coltype : NumericType ) -> str :
544399 return self .to_string (f"cast({ value } as decimal(38,{ coltype .precision } ))" )
545400
546401 def select_table_schema (self , path : DbPath ) -> str :
@@ -554,11 +409,11 @@ def select_table_schema(self, path: DbPath) -> str:
554409 def _parse_type (
555410 self , col_name : str , type_repr : str , datetime_precision : int = None , numeric_precision : int = None
556411 ) -> ColType :
557- regexps = {
412+ timestamp_regexps = {
558413 r"timestamp\((\d)\)" : Timestamp ,
559414 r"timestamp\((\d)\) with time zone" : TimestampTZ ,
560415 }
561- for regexp , cls in regexps .items ():
416+ for regexp , cls in timestamp_regexps .items ():
562417 m = re .match (regexp + "$" , type_repr )
563418 if m :
564419 datetime_precision = int (m .group (1 ))
@@ -567,8 +422,8 @@ def _parse_type(
567422 rounds = False ,
568423 )
569424
570- regexps = {r"decimal\((\d+),(\d+)\)" : Decimal }
571- for regexp , cls in regexps .items ():
425+ number_regexps = {r"decimal\((\d+),(\d+)\)" : Decimal }
426+ for regexp , cls in number_regexps .items ():
572427 m = re .match (regexp + "$" , type_repr )
573428 if m :
574429 prec , scale = map (int , m .groups ())
@@ -632,14 +487,14 @@ def md5_to_int(self, s: str) -> str:
632487 def to_string (self , s : str ):
633488 return f"cast({ s } as char)"
634489
635- def normalize_timestamp (self , value : str , coltype : ColType ) -> str :
490+ def normalize_timestamp (self , value : str , coltype : TemporalType ) -> str :
636491 if coltype .rounds :
637492 return self .to_string (f"cast( cast({ value } as datetime({ coltype .precision } )) as datetime(6))" )
638493
639494 s = self .to_string (f"cast({ value } as datetime(6))" )
640495 return f"RPAD(RPAD({ s } , { TIMESTAMP_PRECISION_POS + coltype .precision } , '.'), { TIMESTAMP_PRECISION_POS + 6 } , '0')"
641496
642- def normalize_number (self , value : str , coltype : ColType ) -> str :
497+ def normalize_number (self , value : str , coltype : NumericType ) -> str :
643498 return self .to_string (f"cast({ value } as decimal(38, { coltype .precision } ))" )
644499
645500
@@ -685,10 +540,10 @@ def select_table_schema(self, path: DbPath) -> str:
685540 f" FROM USER_TAB_COLUMNS WHERE table_name = '{ table .upper ()} '"
686541 )
687542
688- def normalize_timestamp (self , value : str , coltype : ColType ) -> str :
543+ def normalize_timestamp (self , value : str , coltype : TemporalType ) -> str :
689544 return f"to_char(cast({ value } as timestamp({ coltype .precision } )), 'YYYY-MM-DD HH24:MI:SS.FF6')"
690545
691- def normalize_number (self , value : str , coltype : ColType ) -> str :
546+ def normalize_number (self , value : str , coltype : NumericType ) -> str :
692547 # FM999.9990
693548 format_str = "FM" + "9" * (38 - coltype .precision )
694549 if coltype .precision :
@@ -749,7 +604,7 @@ class Redshift(PostgreSQL):
749604 def md5_to_int (self , s : str ) -> str :
750605 return f"strtol(substring(md5({ s } ), { 1 + MD5_HEXDIGITS - CHECKSUM_HEXDIGITS } ), 16)::decimal(38)"
751606
752- def normalize_timestamp (self , value : str , coltype : ColType ) -> str :
607+ def normalize_timestamp (self , value : str , coltype : TemporalType ) -> str :
753608 if coltype .rounds :
754609 timestamp = f"{ value } ::timestamp(6)"
755610 # Get seconds since epoch. Redshift doesn't support milli- or micro-seconds.
@@ -769,7 +624,7 @@ def normalize_timestamp(self, value: str, coltype: ColType) -> str:
769624 f"RPAD(LEFT({ timestamp6 } , { TIMESTAMP_PRECISION_POS + coltype .precision } ), { TIMESTAMP_PRECISION_POS + 6 } , '0')"
770625 )
771626
772- def normalize_number (self , value : str , coltype : ColType ) -> str :
627+ def normalize_number (self , value : str , coltype : NumericType ) -> str :
773628 return self .to_string (f"{ value } ::decimal(38,{ coltype .precision } )" )
774629
775630 def select_table_schema (self , path : DbPath ) -> str :
@@ -870,7 +725,7 @@ def select_table_schema(self, path: DbPath) -> str:
870725 f"WHERE table_name = '{ table } ' AND table_schema = '{ schema } '"
871726 )
872727
873- def normalize_timestamp (self , value : str , coltype : ColType ) -> str :
728+ def normalize_timestamp (self , value : str , coltype : TemporalType ) -> str :
874729 if coltype .rounds :
875730 timestamp = f"timestamp_micros(cast(round(unix_micros(cast({ value } as timestamp))/1000000, { coltype .precision } )*1000000 as int))"
876731 return f"FORMAT_TIMESTAMP('%F %H:%M:%E6S', { timestamp } )"
@@ -885,7 +740,7 @@ def normalize_timestamp(self, value: str, coltype: ColType) -> str:
885740 f"RPAD(LEFT({ timestamp6 } , { TIMESTAMP_PRECISION_POS + coltype .precision } ), { TIMESTAMP_PRECISION_POS + 6 } , '0')"
886741 )
887742
888- def normalize_number (self , value : str , coltype : ColType ) -> str :
743+ def normalize_number (self , value : str , coltype : NumericType ) -> str :
889744 if isinstance (coltype , Integer ):
890745 return self .to_string (value )
891746 return f"format('%.{ coltype .precision } f', { value } )"
@@ -962,21 +817,21 @@ def select_table_schema(self, path: DbPath) -> str:
962817 schema , table = self ._normalize_table_path (path )
963818 return super ().select_table_schema ((schema , table ))
964819
965- def normalize_timestamp (self , value : str , coltype : ColType ) -> str :
820+ def normalize_timestamp (self , value : str , coltype : TemporalType ) -> str :
966821 if coltype .rounds :
967822 timestamp = f"to_timestamp(round(date_part(epoch_nanosecond, { value } ::timestamp(9))/1000000000, { coltype .precision } ))"
968823 else :
969824 timestamp = f"cast({ value } as timestamp({ coltype .precision } ))"
970825
971826 return f"to_char({ timestamp } , 'YYYY-MM-DD HH24:MI:SS.FF6')"
972827
973- def normalize_number (self , value : str , coltype : ColType ) -> str :
828+ def normalize_number (self , value : str , coltype : NumericType ) -> str :
974829 return self .to_string (f"cast({ value } as decimal(38, { coltype .precision } ))" )
975830
976831
977832@dataclass
978833class MatchUriPath :
979- database_cls : type
834+ database_cls : Type [ Database ]
980835 params : List [str ]
981836 kwparams : List [str ] = []
982837 help_str : str
@@ -1027,7 +882,7 @@ def match_path(self, dsn):
1027882 "postgresql" : MatchUriPath (PostgreSQL , ["database?" ], help_str = "postgresql://<user>:<pass>@<host>/<database>" ),
1028883 "mysql" : MatchUriPath (MySQL , ["database?" ], help_str = "mysql://<user>:<pass>@<host>/<database>" ),
1029884 "oracle" : MatchUriPath (Oracle , ["database?" ], help_str = "oracle://<user>:<pass>@<host>/<database>" ),
1030- "mssql" : MatchUriPath (MsSQL , ["database?" ], help_str = "mssql://<user>:<pass>@<host>/<database>" ),
885+ # "mssql": MatchUriPath(MsSQL, ["database?"], help_str="mssql://<user>:<pass>@<host>/<database>"),
1031886 "redshift" : MatchUriPath (Redshift , ["database?" ], help_str = "redshift://<user>:<pass>@<host>/<database>" ),
1032887 "snowflake" : MatchUriPath (
1033888 Snowflake ,
@@ -1055,7 +910,6 @@ def connect_to_uri(db_uri: str, thread_count: Optional[int] = 1) -> Database:
1055910 Supported schemes:
1056911 - postgresql
1057912 - mysql
1058- - mssql
1059913 - oracle
1060914 - snowflake
1061915 - bigquery
0 commit comments