Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit c585550

Browse files
authored
Merge pull request #223 from datafold/sep2
Small bugfixes and refactor
2 parents e6b9ffc + 8093bf7 commit c585550

File tree

6 files changed: +276 additions, -278 deletions

data_diff/__init__.py

Lines changed: 2 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -2,15 +2,8 @@
22

33
from .tracking import disable_tracking
44
from .databases.connect import connect
5-
from .diff_tables import (
6-
TableSegment,
7-
TableDiffer,
8-
DEFAULT_BISECTION_THRESHOLD,
9-
DEFAULT_BISECTION_FACTOR,
10-
DbKey,
11-
DbTime,
12-
DbPath,
13-
)
5+
from .databases.database_types import DbKey, DbTime, DbPath
6+
from .diff_tables import TableSegment, TableDiffer, DEFAULT_BISECTION_THRESHOLD, DEFAULT_BISECTION_FACTOR
147

158

169
def connect_to_table(

data_diff/__main__.py

Lines changed: 6 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -5,23 +5,18 @@
55
import logging
66
from itertools import islice
77

8-
from data_diff.tracking import disable_tracking
8+
import rich
9+
import click
910

10-
from .utils import remove_password_from_url, safezip, match_like
1111

12-
from .diff_tables import (
13-
TableSegment,
14-
TableDiffer,
15-
DEFAULT_BISECTION_THRESHOLD,
16-
DEFAULT_BISECTION_FACTOR,
17-
create_schema,
18-
)
12+
from .utils import remove_password_from_url, safezip, match_like
13+
from .diff_tables import TableDiffer, DEFAULT_BISECTION_THRESHOLD, DEFAULT_BISECTION_FACTOR
14+
from .table_segment import create_schema, TableSegment
1915
from .databases.connect import connect
2016
from .parse_time import parse_time_before_now, UNITS_STR, ParseError
2117
from .config import apply_config_from_file
18+
from .tracking import disable_tracking
2219

23-
import rich
24-
import click
2520

2621
LOG_FORMAT = "[%(asctime)s] %(levelname)s - %(message)s"
2722
DATE_FORMAT = "%H:%M:%S"

data_diff/diff_tables.py

Lines changed: 7 additions & 257 deletions
Original file line number | Diff line number | Diff line change
@@ -6,275 +6,24 @@
66
from numbers import Number
77
from operator import attrgetter, methodcaller
88
from collections import defaultdict
9-
from typing import List, Tuple, Iterator, Optional
9+
from typing import Tuple, Iterator, Optional
1010
import logging
1111
from concurrent.futures import ThreadPoolExecutor, as_completed
1212

1313
from runtype import dataclass
1414

15+
from .utils import safezip, run_as_daemon
16+
from .databases.database_types import IKey, NumericType, PrecisionType, StringType
17+
from .table_segment import TableSegment
1518
from .tracking import create_end_event_json, create_start_event_json, send_event_json, is_tracking_enabled
16-
from .sql import Select, Checksum, Compare, Count, TableName, Time, Value
17-
from .utils import (
18-
CaseAwareMapping,
19-
CaseInsensitiveDict,
20-
safezip,
21-
split_space,
22-
CaseSensitiveDict,
23-
ArithString,
24-
run_as_daemon,
25-
)
26-
from .databases.base import Database
27-
from .databases.database_types import (
28-
DbPath,
29-
DbKey,
30-
DbTime,
31-
IKey,
32-
Native_UUID,
33-
NumericType,
34-
PrecisionType,
35-
StringType,
36-
Schema,
37-
)
3819

3920
logger = logging.getLogger("diff_tables")
4021

41-
RECOMMENDED_CHECKSUM_DURATION = 10
4222
BENCHMARK = os.environ.get("BENCHMARK", False)
4323
DEFAULT_BISECTION_THRESHOLD = 1024 * 16
4424
DEFAULT_BISECTION_FACTOR = 32
4525

4626

47-
def create_schema(db: Database, table_path: DbPath, schema: dict, case_sensitive: bool) -> CaseAwareMapping:
48-
logger.debug(f"[{db.name}] Schema = {schema}")
49-
50-
if case_sensitive:
51-
return CaseSensitiveDict(schema)
52-
53-
if len({k.lower() for k in schema}) < len(schema):
54-
logger.warning(f'Ambiguous schema for {db}:{".".join(table_path)} | Columns = {", ".join(list(schema))}')
55-
logger.warning("We recommend to disable case-insensitivity (remove --any-case).")
56-
return CaseInsensitiveDict(schema)
57-
58-
59-
@dataclass(frozen=False)
60-
class TableSegment:
61-
"""Signifies a segment of rows (and selected columns) within a table
62-
63-
Parameters:
64-
database (Database): Database instance. See :meth:`connect`
65-
table_path (:data:`DbPath`): Path to table in form of a tuple. e.g. `('my_dataset', 'table_name')`
66-
key_column (str): Name of the key column, which uniquely identifies each row (usually id)
67-
update_column (str, optional): Name of updated column, which signals that rows changed (usually updated_at or last_update)
68-
extra_columns (Tuple[str, ...], optional): Extra columns to compare
69-
min_key (:data:`DbKey`, optional): Lowest key_column value, used to restrict the segment
70-
max_key (:data:`DbKey`, optional): Highest key_column value, used to restrict the segment
71-
min_update (:data:`DbTime`, optional): Lowest update_column value, used to restrict the segment
72-
max_update (:data:`DbTime`, optional): Highest update_column value, used to restrict the segment
73-
where (str, optional): An additional 'where' expression to restrict the search space.
74-
75-
case_sensitive (bool): If false, the case of column names will adjust according to the schema. Default is true.
76-
77-
"""
78-
79-
# Location of table
80-
database: Database
81-
table_path: DbPath
82-
83-
# Columns
84-
key_column: str
85-
update_column: str = None
86-
extra_columns: Tuple[str, ...] = ()
87-
88-
# Restrict the segment
89-
min_key: DbKey = None
90-
max_key: DbKey = None
91-
min_update: DbTime = None
92-
max_update: DbTime = None
93-
94-
where: str = None
95-
case_sensitive: bool = True
96-
_schema: Schema = None
97-
98-
def __post_init__(self):
99-
if not self.update_column and (self.min_update or self.max_update):
100-
raise ValueError("Error: the min_update/max_update feature requires 'update_column' to be set.")
101-
102-
if self.min_key is not None and self.max_key is not None and self.min_key >= self.max_key:
103-
raise ValueError(f"Error: min_key expected to be smaller than max_key! ({self.min_key} >= {self.max_key})")
104-
105-
if self.min_update is not None and self.max_update is not None and self.min_update >= self.max_update:
106-
raise ValueError(
107-
f"Error: min_update expected to be smaller than max_update! ({self.min_update} >= {self.max_update})"
108-
)
109-
110-
@property
111-
def _update_column(self):
112-
return self._quote_column(self.update_column)
113-
114-
def _quote_column(self, c: str) -> str:
115-
if self._schema:
116-
c = self._schema.get_key(c) # Get the actual name. Might be case-insensitive.
117-
return self.database.quote(c)
118-
119-
def _normalize_column(self, name: str, template: str = None) -> str:
120-
if not self._schema:
121-
raise RuntimeError(
122-
"Cannot compile query when the schema is unknown. Please use TableSegment.with_schema()."
123-
)
124-
125-
col_type = self._schema[name]
126-
col = self._quote_column(name)
127-
128-
if isinstance(col_type, Native_UUID):
129-
# Normalize first, apply template after (for uuids)
130-
# Needed because min/max(uuid) fails in postgresql
131-
col = self.database.normalize_value_by_type(col, col_type)
132-
if template is not None:
133-
col = template % col # Apply template using Python's string formatting
134-
return col
135-
136-
# Apply template before normalizing (for ints)
137-
if template is not None:
138-
col = template % col # Apply template using Python's string formatting
139-
140-
return self.database.normalize_value_by_type(col, col_type)
141-
142-
def _with_raw_schema(self, raw_schema: dict) -> "TableSegment":
143-
schema = self.database._process_table_schema(self.table_path, raw_schema, self._relevant_columns)
144-
return self.new(_schema=create_schema(self.database, self.table_path, schema, self.case_sensitive))
145-
146-
def with_schema(self) -> "TableSegment":
147-
"Queries the table schema from the database, and returns a new instance of TableSegment, with a schema."
148-
if self._schema:
149-
return self
150-
151-
return self._with_raw_schema(self.database.query_table_schema(self.table_path))
152-
153-
def _make_key_range(self):
154-
if self.min_key is not None:
155-
yield Compare("<=", Value(self.min_key), self._quote_column(self.key_column))
156-
if self.max_key is not None:
157-
yield Compare("<", self._quote_column(self.key_column), Value(self.max_key))
158-
159-
def _make_update_range(self):
160-
if self.min_update is not None:
161-
yield Compare("<=", Time(self.min_update), self._update_column)
162-
if self.max_update is not None:
163-
yield Compare("<", self._update_column, Time(self.max_update))
164-
165-
def _make_select(self, *, table=None, columns=None, where=None, group_by=None, order_by=None):
166-
if columns is None:
167-
columns = [self._normalize_column(self.key_column)]
168-
where = [
169-
*self._make_key_range(),
170-
*self._make_update_range(),
171-
*([] if where is None else [where]),
172-
*([] if self.where is None else [self.where]),
173-
]
174-
order_by = None if order_by is None else [order_by]
175-
return Select(
176-
table=table or TableName(self.table_path),
177-
where=where,
178-
columns=columns,
179-
group_by=group_by,
180-
order_by=order_by,
181-
)
182-
183-
def get_values(self) -> list:
184-
"Download all the relevant values of the segment from the database"
185-
select = self._make_select(columns=self._relevant_columns_repr)
186-
return self.database.query(select, List[Tuple])
187-
188-
def choose_checkpoints(self, count: int) -> List[DbKey]:
189-
"Suggests a bunch of evenly-spaced checkpoints to split by (not including start, end)"
190-
assert self.is_bounded
191-
if isinstance(self.min_key, ArithString):
192-
assert type(self.min_key) is type(self.max_key)
193-
checkpoints = split_space(self.min_key.int, self.max_key.int, count)
194-
return [self.min_key.new(int=i) for i in checkpoints]
195-
196-
return split_space(self.min_key, self.max_key, count)
197-
198-
def segment_by_checkpoints(self, checkpoints: List[DbKey]) -> List["TableSegment"]:
199-
"Split the current TableSegment to a bunch of smaller ones, separated by the given checkpoints"
200-
201-
if self.min_key and self.max_key:
202-
assert all(self.min_key <= c < self.max_key for c in checkpoints)
203-
checkpoints.sort()
204-
205-
# Calculate sub-segments
206-
positions = [self.min_key] + checkpoints + [self.max_key]
207-
ranges = list(zip(positions[:-1], positions[1:]))
208-
209-
# Create table segments
210-
tables = [self.new(min_key=s, max_key=e) for s, e in ranges]
211-
212-
return tables
213-
214-
def new(self, **kwargs) -> "TableSegment":
215-
"""Using new() creates a copy of the instance using 'replace()'"""
216-
return self.replace(**kwargs)
217-
218-
@property
219-
def _relevant_columns(self) -> List[str]:
220-
extras = list(self.extra_columns)
221-
222-
if self.update_column and self.update_column not in extras:
223-
extras = [self.update_column] + extras
224-
225-
return [self.key_column] + extras
226-
227-
@property
228-
def _relevant_columns_repr(self) -> List[str]:
229-
return [self._normalize_column(c) for c in self._relevant_columns]
230-
231-
def count(self) -> Tuple[int, int]:
232-
"""Count how many rows are in the segment, in one pass."""
233-
return self.database.query(self._make_select(columns=[Count()]), int)
234-
235-
def count_and_checksum(self) -> Tuple[int, int]:
236-
"""Count and checksum the rows in the segment, in one pass."""
237-
start = time.monotonic()
238-
count, checksum = self.database.query(
239-
self._make_select(columns=[Count(), Checksum(self._relevant_columns_repr)]), tuple
240-
)
241-
duration = time.monotonic() - start
242-
if duration > RECOMMENDED_CHECKSUM_DURATION:
243-
logger.warning(
244-
f"Checksum is taking longer than expected ({duration:.2f}s). "
245-
"We recommend increasing --bisection-factor or decreasing --threads."
246-
)
247-
248-
if count:
249-
assert checksum, (count, checksum)
250-
return count or 0, checksum if checksum is None else int(checksum)
251-
252-
def query_key_range(self) -> Tuple[int, int]:
253-
"""Query database for minimum and maximum key. This is used for setting the initial bounds."""
254-
# Normalizes the result (needed for UUIDs) after the min/max computation
255-
select = self._make_select(
256-
columns=[
257-
self._normalize_column(self.key_column, "min(%s)"),
258-
self._normalize_column(self.key_column, "max(%s)"),
259-
]
260-
)
261-
min_key, max_key = self.database.query(select, tuple)
262-
263-
if min_key is None or max_key is None:
264-
raise ValueError("Table appears to be empty")
265-
266-
return min_key, max_key
267-
268-
@property
269-
def is_bounded(self):
270-
return self.min_key is not None and self.max_key is not None
271-
272-
def approximate_size(self):
273-
if not self.is_bounded:
274-
raise RuntimeError("Cannot approximate the size of an unbounded segment. Must have min_key and max_key.")
275-
return self.max_key - self.min_key
276-
277-
27827
def diff_sets(a: set, b: set) -> Iterator:
27928
s1 = set(a)
28029
s2 = set(b)
@@ -346,6 +95,7 @@ def diff_tables(self, table1: TableSegment, table2: TableSegment) -> DiffResult:
34695

34796
self.stats["diff_count"] = 0
34897
start = time.monotonic()
98+
error = None
34999
try:
350100

351101
# Query and validate schema
@@ -388,7 +138,6 @@ def diff_tables(self, table1: TableSegment, table2: TableSegment) -> DiffResult:
388138
post_tables = [t.new(min_key=max_key1, max_key=max_key2) for t in (table1, table2)]
389139
yield from self._bisect_and_diff_tables(*post_tables)
390140

391-
error = None
392141
except BaseException as e: # Catch KeyboardInterrupt too
393142
error = e
394143
finally:
@@ -559,7 +308,8 @@ def _threaded_call(self, func, iterable):
559308

560309
def _thread_as_completed(self, func, iterable):
561310
if not self.threaded:
562-
return map(func, iterable)
311+
yield from map(func, iterable)
312+
return
563313

564314
with ThreadPoolExecutor(max_workers=self.max_threadpool_size) as task_pool:
565315
futures = [task_pool.submit(func, item) for item in iterable]

0 commit comments

Comments
 (0)