11"""Provides classes for performing a table diff
22"""
33
4+ from abc import ABC , abstractmethod
45import time
56from operator import attrgetter , methodcaller
67from collections import defaultdict
7- from typing import List , Tuple , Iterator , Optional , Mapping
8+ from typing import List , Tuple , Iterator , Optional
89import logging
910from concurrent .futures import ThreadPoolExecutor
1011
@@ -36,24 +37,44 @@ def parse_table_name(t):
3637 return tuple (t .split ("." ))
3738
3839
39- class CaseInsensitiveDict (Mapping ):
40- def __init__ (self , initial = ()):
41- self ._dict = {k .lower (): v for k , v in dict (initial ).items ()}
40+ class Schema (ABC ):
41+ @abstractmethod
42+ def get_key (self , key : str ) -> str :
43+ ...
4244
43- def __setitem__ (self , key , value ):
44- self ._dict [key .lower ()] = value
45+ @abstractmethod
46+ def __getitem__ (self , key : str ) -> str :
47+ ...
4548
46- def __getitem__ (self , key ):
47- try :
48- return self ._dict [key .lower ()]
49- except KeyError :
50- raise
49+ @abstractmethod
50+ def __setitem__ (self , key : str , value ):
51+ ...
5152
52- def __iter__ (self ):
53- return iter (self ._dict )
53+ @abstractmethod
54+ def __contains__ (self , key : str ) -> bool :
55+ ...
5456
55- def __len__ (self ):
56- return len (self ._dict )
57+
58+ class Schema_CaseSensitive (dict , Schema ):
59+ def get_key (self , key ):
60+ return key
61+
62+
63+ class Schema_CaseInsensitive (Schema ):
64+ def __init__ (self , initial ):
65+ self ._dict = {k .lower (): (k , v ) for k , v in dict (initial ).items ()}
66+
67+ def get_key (self , key : str ) -> str :
68+ return self ._dict [key .lower ()][0 ]
69+
70+ def __getitem__ (self , key : str ) -> str :
71+ return self ._dict [key .lower ()][1 ]
72+
73+ def __setitem__ (self , key : str , value ):
74+ self ._dict [key .lower ()] = key , value
75+
76+ def __contains__ (self , key ):
77+ return key .lower () in self ._dict
5778
5879
5980@dataclass (frozen = False )
@@ -88,8 +109,8 @@ class TableSegment:
88109 min_update : DbTime = None
89110 max_update : DbTime = None
90111
91- quote_columns : bool = True
92- _schema : Mapping [ str , ColType ] = None
112+ case_sensitive : bool = True
113+ _schema : Schema = None
93114
94115 def __post_init__ (self ):
95116 if not self .update_column and (self .min_update or self .max_update ):
@@ -110,17 +131,24 @@ def _update_column(self):
110131 return self ._quote_column (self .update_column )
111132
112133 def _quote_column (self , c ):
113- if self .quote_columns :
114- return self .database . quote (c )
115- return c
134+ if self ._schema :
135+ c = self ._schema . get_key (c )
136+ return self . database . quote ( c )
116137
117138 def with_schema (self ) -> "TableSegment" :
118139 "Queries the table schema from the database, and returns a new instance of TableSegmentWithSchema."
119140 if self ._schema :
120141 return self
121142 schema = self .database .query_table_schema (self .table_path )
122- if not self .quote_columns :
123- schema = CaseInsensitiveDict (schema )
143+ if self .case_sensitive :
144+ schema = Schema_CaseSensitive (schema )
145+ else :
146+ if len ({k .lower () for k in schema }) < len (schema ):
147+ logger .warn (
148+ f'Ambiguous schema for { self .database } :{ "." .join (self .table_path )} | Columns = { ", " .join (list (schema ))} '
149+ )
150+ logger .warn ("We recommend to disable case-insensitivity (remove --any-case)." )
151+ schema = Schema_CaseInsensitive (schema )
124152 return self .new (_schema = schema )
125153
126154 def _make_key_range (self ):
0 commit comments