Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit 57cb682

Browse files
committed
Bugfix in TableSegment: Sampling now respects the 'where' clause (issue #221)
1 parent 8093bf7 commit 57cb682

File tree

5 files changed

+22
-9
lines changed

5 files changed

+22
-9
lines changed

data_diff/databases/base.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -187,25 +187,28 @@ def query_table_schema(self, path: DbPath) -> Dict[str, tuple]:
187187
assert len(d) == len(rows)
188188
return d
189189

190-
def _process_table_schema(self, path: DbPath, raw_schema: Dict[str, tuple], filter_columns: Sequence[str]):
190+
def _process_table_schema(self, path: DbPath, raw_schema: Dict[str, tuple], filter_columns: Sequence[str], where: str = None):
191191
accept = {i.lower() for i in filter_columns}
192192

193193
col_dict = {row[0]: self._parse_type(path, *row) for name, row in raw_schema.items() if name.lower() in accept}
194194

195-
self._refine_coltypes(path, col_dict)
195+
self._refine_coltypes(path, col_dict, where)
196196

197197
# Return a dict of form {name: type} after normalization
198198
return col_dict
199199

200-
def _refine_coltypes(self, table_path: DbPath, col_dict: Dict[str, ColType]):
201-
"Refine the types in the column dict, by querying the database for a sample of their values"
200+
def _refine_coltypes(self, table_path: DbPath, col_dict: Dict[str, ColType], where: str = None):
201+
"""Refine the types in the column dict, by querying the database for a sample of their values
202+
203+
'where' restricts the rows to be sampled.
204+
"""
202205

203206
text_columns = [k for k, v in col_dict.items() if isinstance(v, Text)]
204207
if not text_columns:
205208
return
206209

207210
fields = [self.normalize_uuid(c, String_UUID()) for c in text_columns]
208-
samples_by_row = self.query(Select(fields, TableName(table_path), limit=16), list)
211+
samples_by_row = self.query(Select(fields, TableName(table_path), limit=16, where=where and [where]), list)
209212
if not samples_by_row:
210213
raise ValueError(f"Table {table_path} is empty.")
211214

data_diff/databases/database_types.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,7 @@ def query_table_schema(self, path: DbPath) -> Dict[str, tuple]:
177177
...
178178

179179
@abstractmethod
180-
def _process_table_schema(self, path: DbPath, raw_schema: Dict[str, tuple], filter_columns: Sequence[str]):
180+
def _process_table_schema(self, path: DbPath, raw_schema: Dict[str, tuple], filter_columns: Sequence[str], where: str = None):
181181
"""Process the result of query_table_schema().
182182
183183
Done in a separate step, to minimize the amount of processed columns.

data_diff/databases/databricks.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def query_table_schema(self, path: DbPath) -> Dict[str, tuple]:
8383
assert len(d) == len(rows)
8484
return d
8585

86-
def _process_table_schema(self, path: DbPath, raw_schema: Dict[str, tuple], filter_columns: Sequence[str]):
86+
def _process_table_schema(self, path: DbPath, raw_schema: Dict[str, tuple], filter_columns: Sequence[str], where: str = None):
8787
accept = {i.lower() for i in filter_columns}
8888
rows = [row for name, row in raw_schema.items() if name.lower() in accept]
8989

@@ -115,7 +115,7 @@ def _process_table_schema(self, path: DbPath, raw_schema: Dict[str, tuple], filt
115115

116116
col_dict: Dict[str, ColType] = {row[0]: self._parse_type(path, *row) for row in resulted_rows}
117117

118-
self._refine_coltypes(path, col_dict)
118+
self._refine_coltypes(path, col_dict, where)
119119
return col_dict
120120

121121
def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:

data_diff/table_segment.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ def _normalize_column(self, name: str, template: str = None) -> str:
111111
return self.database.normalize_value_by_type(col, col_type)
112112

113113
def _with_raw_schema(self, raw_schema: dict) -> "TableSegment":
114-
schema = self.database._process_table_schema(self.table_path, raw_schema, self._relevant_columns)
114+
schema = self.database._process_table_schema(self.table_path, raw_schema, self._relevant_columns, self.where)
115115
return self.new(_schema=create_schema(self.database, self.table_path, schema, self.case_sensitive))
116116

117117
def with_schema(self) -> "TableSegment":

tests/test_diff_tables.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -443,6 +443,16 @@ def test_string_keys(self):
443443

444444
self.assertRaises(ValueError, list, differ.diff_tables(self.a, self.b))
445445

446+
def test_where_sampling(self):
447+
a = self.a.replace(where='1=1')
448+
449+
differ = TableDiffer()
450+
diff = list(differ.diff_tables(a, self.b))
451+
self.assertEqual(diff, [("-", (str(self.new_uuid), "This one is different"))])
452+
453+
a_empty = self.a.replace(where='1=0')
454+
self.assertRaises(ValueError, list, differ.diff_tables(a_empty, self.b))
455+
446456

447457
@test_per_database
448458
class TestAlphanumericKeys(TestPerDatabase):

0 commit comments

Comments
 (0)