Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit c43f006

Browse files
committed
Fixed support for diffing columns of different names (Issue #229)
1 parent 0cd0c40 commit c43f006

File tree

2 files changed, with 70 additions and 33 deletions.

data_diff/diff_tables.py

Lines changed: 14 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -172,42 +172,42 @@ def _parse_key_range_result(self, key_type, key_range):
172172
raise type(e)(f"Cannot apply {key_type} to {mn}, {mx}.") from e
173173

174174
def _validate_and_adjust_columns(self, table1, table2):
175-
for c in table1._relevant_columns:
176-
if c not in table1._schema:
175+
for c1, c2 in safezip(table1._relevant_columns, table2._relevant_columns):
176+
if c1 not in table1._schema:
177177
raise ValueError(f"Column '{c}' not found in schema for table {table1}")
178-
if c not in table2._schema:
178+
if c2 not in table2._schema:
179179
raise ValueError(f"Column '{c}' not found in schema for table {table2}")
180180

181181
# Update schemas to minimal mutual precision
182-
col1 = table1._schema[c]
183-
col2 = table2._schema[c]
182+
col1 = table1._schema[c1]
183+
col2 = table2._schema[c2]
184184
if isinstance(col1, PrecisionType):
185185
if not isinstance(col2, PrecisionType):
186-
raise TypeError(f"Incompatible types for column '{c}': {col1} <-> {col2}")
186+
raise TypeError(f"Incompatible types for column '{c1}': {col1} <-> {col2}")
187187

188188
lowest = min(col1, col2, key=attrgetter("precision"))
189189

190190
if col1.precision != col2.precision:
191-
logger.warning(f"Using reduced precision {lowest} for column '{c}'. Types={col1}, {col2}")
191+
logger.warning(f"Using reduced precision {lowest} for column '{c1}'. Types={col1}, {col2}")
192192

193-
table1._schema[c] = col1.replace(precision=lowest.precision, rounds=lowest.rounds)
194-
table2._schema[c] = col2.replace(precision=lowest.precision, rounds=lowest.rounds)
193+
table1._schema[c1] = col1.replace(precision=lowest.precision, rounds=lowest.rounds)
194+
table2._schema[c2] = col2.replace(precision=lowest.precision, rounds=lowest.rounds)
195195

196196
elif isinstance(col1, NumericType):
197197
if not isinstance(col2, NumericType):
198-
raise TypeError(f"Incompatible types for column '{c}': {col1} <-> {col2}")
198+
raise TypeError(f"Incompatible types for column '{c1}': {col1} <-> {col2}")
199199

200200
lowest = min(col1, col2, key=attrgetter("precision"))
201201

202202
if col1.precision != col2.precision:
203-
logger.warning(f"Using reduced precision {lowest} for column '{c}'. Types={col1}, {col2}")
203+
logger.warning(f"Using reduced precision {lowest} for column '{c1}'. Types={col1}, {col2}")
204204

205-
table1._schema[c] = col1.replace(precision=lowest.precision)
206-
table2._schema[c] = col2.replace(precision=lowest.precision)
205+
table1._schema[c1] = col1.replace(precision=lowest.precision)
206+
table2._schema[c2] = col2.replace(precision=lowest.precision)
207207

208208
elif isinstance(col1, StringType):
209209
if not isinstance(col2, StringType):
210-
raise TypeError(f"Incompatible types for column '{c}': {col1} <-> {col2}")
210+
raise TypeError(f"Incompatible types for column '{c1}': {col1} <-> {col2}")
211211

212212
for t in [table1, table2]:
213213
for c in t._relevant_columns:

tests/test_diff_tables.py

Lines changed: 56 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -234,25 +234,7 @@ def setUp(self):
234234
f"create table {self.table_dst}(id int, userid int, movieid int, rating {float_type}, timestamp timestamp)",
235235
None,
236236
)
237-
# self.preql(
238-
# f"""
239-
# table {self.table_src_name} {{
240-
# userid: int
241-
# movieid: int
242-
# rating: float
243-
# timestamp: timestamp
244-
# }}
245-
246-
# table {self.table_dst_name} {{
247-
# userid: int
248-
# movieid: int
249-
# rating: float
250-
# timestamp: timestamp
251-
# }}
252-
# commit()
253-
# """
254-
# )
255-
self.preql.commit()
237+
_commit(self.connection)
256238

257239
self.table = TableSegment(self.connection, self.table_src_path, "id", "timestamp", case_sensitive=False)
258240
self.table2 = TableSegment(self.connection, self.table_dst_path, "id", "timestamp", case_sensitive=False)
@@ -402,6 +384,61 @@ def test_diff_sorted_by_key(self):
402384
self.assertEqual(expected, diff)
403385

404386

387+
@test_per_database
388+
class TestDiffTables2(TestPerDatabase):
389+
390+
def test_diff_column_names(self):
391+
float_type = _get_float_type(self.connection)
392+
393+
self.connection.query(
394+
f"create table {self.table_src}(id int, rating {float_type}, timestamp timestamp)",
395+
None,
396+
)
397+
self.connection.query(
398+
f"create table {self.table_dst}(id2 int, rating2 {float_type}, timestamp2 timestamp)",
399+
None,
400+
)
401+
_commit(self.connection)
402+
403+
time = "2022-01-01 00:00:00"
404+
time2 = "2021-01-01 00:00:00"
405+
406+
time_str = f"timestamp '{time}'"
407+
time_str2 = f"timestamp '{time2}'"
408+
_insert_rows(
409+
self.connection,
410+
self.table_src,
411+
["id", "rating", "timestamp"],
412+
[
413+
[1, 9, time_str],
414+
[2, 9, time_str2],
415+
[3, 9, time_str],
416+
[4, 9, time_str2],
417+
[5, 9, time_str],
418+
],
419+
)
420+
421+
_insert_rows(
422+
self.connection,
423+
self.table_dst,
424+
["id2", "rating2", "timestamp2"],
425+
[
426+
[1, 9, time_str],
427+
[2, 9, time_str2],
428+
[3, 9, time_str],
429+
[4, 9, time_str2],
430+
[5, 9, time_str],
431+
],
432+
)
433+
434+
table1 = TableSegment(self.connection, self.table_src_path, "id", "timestamp", case_sensitive=False)
435+
table2 = TableSegment(self.connection, self.table_dst_path, "id2", "timestamp2", case_sensitive=False)
436+
437+
differ = TableDiffer()
438+
diff = list(differ.diff_tables(table1, table2))
439+
assert diff == []
440+
441+
405442
@test_per_database
406443
class TestUUIDs(TestPerDatabase):
407444
def setUp(self):

0 commit comments

Comments
 (0)