11"""Provides classes for performing a table diff
22"""
33
4+ from abc import ABC , abstractmethod
45import time
56from operator import attrgetter , methodcaller
67from collections import defaultdict
7- from typing import List , Tuple , Iterator , Optional , Mapping
8+ from typing import List , Tuple , Iterator , Optional
89import logging
910from concurrent .futures import ThreadPoolExecutor
1011
@@ -36,24 +37,47 @@ def parse_table_name(t):
3637 return tuple (t .split ("." ))
3738
3839
class Schema(ABC):
    """Abstract mapping-like interface for a table schema.

    Implementations provide item access by column name plus ``get_key``,
    which translates a requested name into the name the schema holds.
    """

    @abstractmethod
    def get_key(self, key: str) -> str:
        """Return the schema's own name for the requested ``key``."""

    @abstractmethod
    def __getitem__(self, key: str) -> str:
        """Return the value stored under ``key``."""

    @abstractmethod
    def __setitem__(self, key: str, value):
        """Store ``value`` under ``key``."""

    @abstractmethod
    def __contains__(self, key: str) -> bool:
        """Return whether ``key`` is present in the schema."""
57+
class Schema_CaseSensitive(dict, Schema):
    """Schema stored as a plain ``dict``; names are used exactly as given."""

    def get_key(self, key: str) -> str:
        """Return ``key`` unchanged."""
        return key
61+
62+
class Schema_CaseInsensitive(Schema):
    """Schema whose lookups ignore the letter case of the key.

    Entries are indexed by the lower-cased name, and each entry is an
    ``(original_name, value)`` pair so the original spelling can be returned
    by ``get_key``.
    """

    def __init__(self, initial):
        # ``initial`` is anything ``dict()`` accepts; if two names differ only
        # by case, the later one overwrites the earlier one.
        self._dict = {}
        for name, value in dict(initial).items():
            self._dict[name.lower()] = (name, value)

    def get_key(self, key: str) -> str:
        """Return the stored spelling of ``key``; raises ``KeyError`` if absent."""
        original_name, _ = self._dict[key.lower()]
        return original_name

    def __getitem__(self, key: str) -> str:
        """Return the value stored for ``key``; raises ``KeyError`` if absent."""
        _, value = self._dict[key.lower()]
        return value

    def __setitem__(self, key: str, value):
        """Store ``value`` for ``key``, keeping the spelling of an existing entry."""
        folded = key.lower()
        existing = self._dict.get(folded)
        # Entries are always tuples, so ``None`` reliably means "not present".
        original_name = key if existing is None else existing[0]
        self._dict[folded] = (original_name, value)

    def __contains__(self, key):
        return key.lower() in self._dict
5781
5882
5983@dataclass (frozen = False )
@@ -88,8 +112,8 @@ class TableSegment:
88112 min_update : DbTime = None
89113 max_update : DbTime = None
90114
91- quote_columns : bool = True
92- _schema : Mapping [ str , ColType ] = None
115+ case_sensitive : bool = True
116+ _schema : Schema = None
93117
94118 def __post_init__ (self ):
95119 if not self .update_column and (self .min_update or self .max_update ):
@@ -110,17 +134,24 @@ def _update_column(self):
110134 return self ._quote_column (self .update_column )
111135
def _quote_column(self, c):
    """Return column name ``c`` quoted by ``self.database.quote``.

    When a schema is attached, ``c`` is first translated with the schema's
    ``get_key`` and the translated name is what gets quoted.
    """
    name = self._schema.get_key(c) if self._schema else c
    return self.database.quote(name)
116140
def with_schema(self) -> "TableSegment":
    """Return a TableSegment that carries the table schema, querying it if needed.

    If a schema is already attached, ``self`` is returned as is. Otherwise the
    schema is fetched with ``self.database.query_table_schema`` and wrapped in
    ``Schema_CaseSensitive`` or ``Schema_CaseInsensitive`` according to
    ``self.case_sensitive``; the result is a new instance built by ``self.new``.
    """
    if self._schema:
        return self

    schema = self.database.query_table_schema(self.table_path)
    if self.case_sensitive:
        schema = Schema_CaseSensitive(schema)
    else:
        # Column names that differ only by case collapse into a single entry
        # of the case-insensitive schema, so warn before wrapping.
        if len({k.lower() for k in schema}) < len(schema):
            # ``Logger.warn`` is a deprecated alias; ``warning`` is the supported method.
            logger.warning(
                f'Ambiguous schema for {self.database}:{".".join(self.table_path)} | Columns = {", ".join(schema)}'
            )
            logger.warning("We recommend to disable case-insensitivity (remove --any-case).")
        schema = Schema_CaseInsensitive(schema)
    return self.new(_schema=schema)
125156
126157 def _make_key_range (self ):
0 commit comments