Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit 8103432

Browse files
authored
Merge pull request #71 from datafold/fix_column_case
Now automatically fixing the column case using the schema.
2 parents 66248cf + 0af7568 commit 8103432

File tree

4 files changed

+59
-24
lines changed

4 files changed

+59
-24
lines changed

data_diff/__main__.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -53,6 +53,7 @@
5353
@click.option("-d", "--debug", is_flag=True, help="Print debug info")
5454
@click.option("-v", "--verbose", is_flag=True, help="Print extra info")
5555
@click.option("-i", "--interactive", is_flag=True, help="Confirm queries, implies --debug")
56+
@click.option("--keep-column-case", is_flag=True, help="Don't use the schema to fix the case of given column names.")
5657
@click.option(
5758
"-j",
5859
"--threads",
@@ -79,6 +80,7 @@ def main(
7980
verbose,
8081
interactive,
8182
threads,
83+
keep_column_case,
8284
):
8385
if limit and stats:
8486
print("Error: cannot specify a limit when using the -s/--stats switch")
@@ -119,6 +121,7 @@ def main(
119121
options = dict(
120122
min_update=max_age and parse_time_before_now(max_age),
121123
max_update=min_age and parse_time_before_now(min_age),
124+
case_sensitive=keep_column_case,
122125
)
123126
except ParseError as e:
124127
logging.error("Error while parsing age expression: %s" % e)

data_diff/diff_tables.py

Lines changed: 53 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -1,10 +1,11 @@
11
"""Provides classes for performing a table diff
22
"""
33

4+
from abc import ABC, abstractmethod
45
import time
56
from operator import attrgetter, methodcaller
67
from collections import defaultdict
7-
from typing import List, Tuple, Iterator, Optional, Mapping
8+
from typing import List, Tuple, Iterator, Optional
89
import logging
910
from concurrent.futures import ThreadPoolExecutor
1011

@@ -36,24 +37,47 @@ def parse_table_name(t):
3637
return tuple(t.split("."))
3738

3839

39-
class CaseInsensitiveDict(Mapping):
40-
def __init__(self, initial=()):
41-
self._dict = {k.lower(): v for k, v in dict(initial).items()}
40+
class Schema(ABC):
41+
@abstractmethod
42+
def get_key(self, key: str) -> str:
43+
...
4244

43-
def __setitem__(self, key, value):
44-
self._dict[key.lower()] = value
45+
@abstractmethod
46+
def __getitem__(self, key: str) -> str:
47+
...
4548

46-
def __getitem__(self, key):
47-
try:
48-
return self._dict[key.lower()]
49-
except KeyError:
50-
raise
49+
@abstractmethod
50+
def __setitem__(self, key: str, value):
51+
...
5152

52-
def __iter__(self):
53-
return iter(self._dict)
53+
@abstractmethod
54+
def __contains__(self, key: str) -> bool:
55+
...
5456

55-
def __len__(self):
56-
return len(self._dict)
57+
58+
class Schema_CaseSensitive(dict, Schema):
59+
def get_key(self, key):
60+
return key
61+
62+
63+
class Schema_CaseInsensitive(Schema):
64+
def __init__(self, initial):
65+
self._dict = {k.lower(): (k, v) for k, v in dict(initial).items()}
66+
67+
def get_key(self, key: str) -> str:
68+
return self._dict[key.lower()][0]
69+
70+
def __getitem__(self, key: str) -> str:
71+
return self._dict[key.lower()][1]
72+
73+
def __setitem__(self, key: str, value):
74+
k = key.lower()
75+
if k in self._dict:
76+
key = self._dict[k][0]
77+
self._dict[k] = key, value
78+
79+
def __contains__(self, key):
80+
return key.lower() in self._dict
5781

5882

5983
@dataclass(frozen=False)
@@ -88,8 +112,8 @@ class TableSegment:
88112
min_update: DbTime = None
89113
max_update: DbTime = None
90114

91-
quote_columns: bool = True
92-
_schema: Mapping[str, ColType] = None
115+
case_sensitive: bool = True
116+
_schema: Schema = None
93117

94118
def __post_init__(self):
95119
if not self.update_column and (self.min_update or self.max_update):
@@ -110,17 +134,24 @@ def _update_column(self):
110134
return self._quote_column(self.update_column)
111135

112136
def _quote_column(self, c):
113-
if self.quote_columns:
114-
return self.database.quote(c)
115-
return c
137+
if self._schema:
138+
c = self._schema.get_key(c)
139+
return self.database.quote(c)
116140

117141
def with_schema(self) -> "TableSegment":
118142
"Queries the table schema from the database, and returns a new instance of TableSegmentWithSchema."
119143
if self._schema:
120144
return self
121145
schema = self.database.query_table_schema(self.table_path)
122-
if not self.quote_columns:
123-
schema = CaseInsensitiveDict(schema)
146+
if self.case_sensitive:
147+
schema = Schema_CaseSensitive(schema)
148+
else:
149+
if len({k.lower() for k in schema}) < len(schema):
150+
logger.warn(
151+
f'Ambiguous schema for {self.database}:{".".join(self.table_path)} | Columns = {", ".join(list(schema))}'
152+
)
153+
logger.warn("We recommend to disable case-insensitivity (remove --any-case).")
154+
schema = Schema_CaseInsensitive(schema)
124155
return self.new(_schema=schema)
125156

126157
def _make_key_range(self):

tests/test_database.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -18,6 +18,7 @@ def test_md5_to_int(self):
1818

1919
self.assertEqual(str_to_checksum(str), self.mysql.query(query, int))
2020

21+
2122
class TestConnect(unittest.TestCase):
2223
def test_bad_uris(self):
2324
self.assertRaises(ValueError, connect_to_uri, "p")

tests/test_database_types.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -214,8 +214,8 @@ def test_types(self, source_db, target_db, source_type, target_type, type_catego
214214
dst_conn.query(f"CREATE TABLE {dst_table}(id int, col {target_type});", None)
215215
_insert_to_table(dst_conn, dst_table, values_in_source)
216216

217-
self.table = TableSegment(self.src_conn, src_table_path, "id", None, ("col",), quote_columns=False)
218-
self.table2 = TableSegment(self.dst_conn, dst_table_path, "id", None, ("col",), quote_columns=False)
217+
self.table = TableSegment(self.src_conn, src_table_path, "id", None, ("col",), case_sensitive=False)
218+
self.table2 = TableSegment(self.dst_conn, dst_table_path, "id", None, ("col",), case_sensitive=False)
219219

220220
self.assertEqual(len(sample_values), self.table.count())
221221
self.assertEqual(len(sample_values), self.table2.count())

0 commit comments

Comments
 (0)