|
6 | 6 | from numbers import Number |
7 | 7 | from operator import attrgetter, methodcaller |
8 | 8 | from collections import defaultdict |
9 | | -from typing import List, Tuple, Iterator, Optional |
| 9 | +from typing import Tuple, Iterator, Optional |
10 | 10 | import logging |
11 | 11 | from concurrent.futures import ThreadPoolExecutor, as_completed |
12 | 12 |
|
13 | 13 | from runtype import dataclass |
14 | 14 |
|
| 15 | +from .utils import safezip, run_as_daemon |
| 16 | +from .databases.database_types import IKey, NumericType, PrecisionType, StringType |
| 17 | +from .table_segment import TableSegment |
15 | 18 | from .tracking import create_end_event_json, create_start_event_json, send_event_json, is_tracking_enabled |
16 | | -from .sql import Select, Checksum, Compare, Count, TableName, Time, Value |
17 | | -from .utils import ( |
18 | | - CaseAwareMapping, |
19 | | - CaseInsensitiveDict, |
20 | | - safezip, |
21 | | - split_space, |
22 | | - CaseSensitiveDict, |
23 | | - ArithString, |
24 | | - run_as_daemon, |
25 | | -) |
26 | | -from .databases.base import Database |
27 | | -from .databases.database_types import ( |
28 | | - DbPath, |
29 | | - DbKey, |
30 | | - DbTime, |
31 | | - IKey, |
32 | | - Native_UUID, |
33 | | - NumericType, |
34 | | - PrecisionType, |
35 | | - StringType, |
36 | | - Schema, |
37 | | -) |
38 | 19 |
|
39 | 20 | logger = logging.getLogger("diff_tables") |
40 | 21 |
|
41 | | -RECOMMENDED_CHECKSUM_DURATION = 10 |
42 | 22 | BENCHMARK = os.environ.get("BENCHMARK", False) |
43 | 23 | DEFAULT_BISECTION_THRESHOLD = 1024 * 16 |
44 | 24 | DEFAULT_BISECTION_FACTOR = 32 |
45 | 25 |
|
46 | 26 |
|
47 | | -def create_schema(db: Database, table_path: DbPath, schema: dict, case_sensitive: bool) -> CaseAwareMapping: |
48 | | - logger.debug(f"[{db.name}] Schema = {schema}") |
49 | | - |
50 | | - if case_sensitive: |
51 | | - return CaseSensitiveDict(schema) |
52 | | - |
53 | | - if len({k.lower() for k in schema}) < len(schema): |
54 | | - logger.warning(f'Ambiguous schema for {db}:{".".join(table_path)} | Columns = {", ".join(list(schema))}') |
55 | | - logger.warning("We recommend to disable case-insensitivity (remove --any-case).") |
56 | | - return CaseInsensitiveDict(schema) |
57 | | - |
58 | | - |
59 | | -@dataclass(frozen=False) |
60 | | -class TableSegment: |
61 | | - """Signifies a segment of rows (and selected columns) within a table |
62 | | -
|
63 | | - Parameters: |
64 | | - database (Database): Database instance. See :meth:`connect` |
65 | | - table_path (:data:`DbPath`): Path to table in form of a tuple. e.g. `('my_dataset', 'table_name')` |
66 | | - key_column (str): Name of the key column, which uniquely identifies each row (usually id) |
67 | | - update_column (str, optional): Name of updated column, which signals that rows changed (usually updated_at or last_update) |
68 | | - extra_columns (Tuple[str, ...], optional): Extra columns to compare |
69 | | - min_key (:data:`DbKey`, optional): Lowest key_column value, used to restrict the segment |
70 | | - max_key (:data:`DbKey`, optional): Highest key_column value, used to restrict the segment |
71 | | - min_update (:data:`DbTime`, optional): Lowest update_column value, used to restrict the segment |
72 | | - max_update (:data:`DbTime`, optional): Highest update_column value, used to restrict the segment |
73 | | - where (str, optional): An additional 'where' expression to restrict the search space. |
74 | | -
|
75 | | - case_sensitive (bool): If false, the case of column names will adjust according to the schema. Default is true. |
76 | | -
|
77 | | - """ |
78 | | - |
79 | | - # Location of table |
80 | | - database: Database |
81 | | - table_path: DbPath |
82 | | - |
83 | | - # Columns |
84 | | - key_column: str |
85 | | - update_column: str = None |
86 | | - extra_columns: Tuple[str, ...] = () |
87 | | - |
88 | | - # Restrict the segment |
89 | | - min_key: DbKey = None |
90 | | - max_key: DbKey = None |
91 | | - min_update: DbTime = None |
92 | | - max_update: DbTime = None |
93 | | - |
94 | | - where: str = None |
95 | | - case_sensitive: bool = True |
96 | | - _schema: Schema = None |
97 | | - |
98 | | - def __post_init__(self): |
99 | | - if not self.update_column and (self.min_update or self.max_update): |
100 | | - raise ValueError("Error: the min_update/max_update feature requires 'update_column' to be set.") |
101 | | - |
102 | | - if self.min_key is not None and self.max_key is not None and self.min_key >= self.max_key: |
103 | | - raise ValueError(f"Error: min_key expected to be smaller than max_key! ({self.min_key} >= {self.max_key})") |
104 | | - |
105 | | - if self.min_update is not None and self.max_update is not None and self.min_update >= self.max_update: |
106 | | - raise ValueError( |
107 | | - f"Error: min_update expected to be smaller than max_update! ({self.min_update} >= {self.max_update})" |
108 | | - ) |
109 | | - |
110 | | - @property |
111 | | - def _update_column(self): |
112 | | - return self._quote_column(self.update_column) |
113 | | - |
114 | | - def _quote_column(self, c: str) -> str: |
115 | | - if self._schema: |
116 | | - c = self._schema.get_key(c) # Get the actual name. Might be case-insensitive. |
117 | | - return self.database.quote(c) |
118 | | - |
119 | | - def _normalize_column(self, name: str, template: str = None) -> str: |
120 | | - if not self._schema: |
121 | | - raise RuntimeError( |
122 | | - "Cannot compile query when the schema is unknown. Please use TableSegment.with_schema()." |
123 | | - ) |
124 | | - |
125 | | - col_type = self._schema[name] |
126 | | - col = self._quote_column(name) |
127 | | - |
128 | | - if isinstance(col_type, Native_UUID): |
129 | | - # Normalize first, apply template after (for uuids) |
130 | | - # Needed because min/max(uuid) fails in postgresql |
131 | | - col = self.database.normalize_value_by_type(col, col_type) |
132 | | - if template is not None: |
133 | | - col = template % col # Apply template using Python's string formatting |
134 | | - return col |
135 | | - |
136 | | - # Apply template before normalizing (for ints) |
137 | | - if template is not None: |
138 | | - col = template % col # Apply template using Python's string formatting |
139 | | - |
140 | | - return self.database.normalize_value_by_type(col, col_type) |
141 | | - |
142 | | - def _with_raw_schema(self, raw_schema: dict) -> "TableSegment": |
143 | | - schema = self.database._process_table_schema(self.table_path, raw_schema, self._relevant_columns) |
144 | | - return self.new(_schema=create_schema(self.database, self.table_path, schema, self.case_sensitive)) |
145 | | - |
146 | | - def with_schema(self) -> "TableSegment": |
147 | | - "Queries the table schema from the database, and returns a new instance of TableSegment, with a schema." |
148 | | - if self._schema: |
149 | | - return self |
150 | | - |
151 | | - return self._with_raw_schema(self.database.query_table_schema(self.table_path)) |
152 | | - |
153 | | - def _make_key_range(self): |
154 | | - if self.min_key is not None: |
155 | | - yield Compare("<=", Value(self.min_key), self._quote_column(self.key_column)) |
156 | | - if self.max_key is not None: |
157 | | - yield Compare("<", self._quote_column(self.key_column), Value(self.max_key)) |
158 | | - |
159 | | - def _make_update_range(self): |
160 | | - if self.min_update is not None: |
161 | | - yield Compare("<=", Time(self.min_update), self._update_column) |
162 | | - if self.max_update is not None: |
163 | | - yield Compare("<", self._update_column, Time(self.max_update)) |
164 | | - |
165 | | - def _make_select(self, *, table=None, columns=None, where=None, group_by=None, order_by=None): |
166 | | - if columns is None: |
167 | | - columns = [self._normalize_column(self.key_column)] |
168 | | - where = [ |
169 | | - *self._make_key_range(), |
170 | | - *self._make_update_range(), |
171 | | - *([] if where is None else [where]), |
172 | | - *([] if self.where is None else [self.where]), |
173 | | - ] |
174 | | - order_by = None if order_by is None else [order_by] |
175 | | - return Select( |
176 | | - table=table or TableName(self.table_path), |
177 | | - where=where, |
178 | | - columns=columns, |
179 | | - group_by=group_by, |
180 | | - order_by=order_by, |
181 | | - ) |
182 | | - |
183 | | - def get_values(self) -> list: |
184 | | - "Download all the relevant values of the segment from the database" |
185 | | - select = self._make_select(columns=self._relevant_columns_repr) |
186 | | - return self.database.query(select, List[Tuple]) |
187 | | - |
188 | | - def choose_checkpoints(self, count: int) -> List[DbKey]: |
189 | | - "Suggests a bunch of evenly-spaced checkpoints to split by (not including start, end)" |
190 | | - assert self.is_bounded |
191 | | - if isinstance(self.min_key, ArithString): |
192 | | - assert type(self.min_key) is type(self.max_key) |
193 | | - checkpoints = split_space(self.min_key.int, self.max_key.int, count) |
194 | | - return [self.min_key.new(int=i) for i in checkpoints] |
195 | | - |
196 | | - return split_space(self.min_key, self.max_key, count) |
197 | | - |
198 | | - def segment_by_checkpoints(self, checkpoints: List[DbKey]) -> List["TableSegment"]: |
199 | | - "Split the current TableSegment to a bunch of smaller ones, separated by the given checkpoints" |
200 | | - |
201 | | - if self.min_key and self.max_key: |
202 | | - assert all(self.min_key <= c < self.max_key for c in checkpoints) |
203 | | - checkpoints.sort() |
204 | | - |
205 | | - # Calculate sub-segments |
206 | | - positions = [self.min_key] + checkpoints + [self.max_key] |
207 | | - ranges = list(zip(positions[:-1], positions[1:])) |
208 | | - |
209 | | - # Create table segments |
210 | | - tables = [self.new(min_key=s, max_key=e) for s, e in ranges] |
211 | | - |
212 | | - return tables |
213 | | - |
214 | | - def new(self, **kwargs) -> "TableSegment": |
215 | | - """Using new() creates a copy of the instance using 'replace()'""" |
216 | | - return self.replace(**kwargs) |
217 | | - |
218 | | - @property |
219 | | - def _relevant_columns(self) -> List[str]: |
220 | | - extras = list(self.extra_columns) |
221 | | - |
222 | | - if self.update_column and self.update_column not in extras: |
223 | | - extras = [self.update_column] + extras |
224 | | - |
225 | | - return [self.key_column] + extras |
226 | | - |
227 | | - @property |
228 | | - def _relevant_columns_repr(self) -> List[str]: |
229 | | - return [self._normalize_column(c) for c in self._relevant_columns] |
230 | | - |
231 | | - def count(self) -> Tuple[int, int]: |
232 | | - """Count how many rows are in the segment, in one pass.""" |
233 | | - return self.database.query(self._make_select(columns=[Count()]), int) |
234 | | - |
235 | | - def count_and_checksum(self) -> Tuple[int, int]: |
236 | | - """Count and checksum the rows in the segment, in one pass.""" |
237 | | - start = time.monotonic() |
238 | | - count, checksum = self.database.query( |
239 | | - self._make_select(columns=[Count(), Checksum(self._relevant_columns_repr)]), tuple |
240 | | - ) |
241 | | - duration = time.monotonic() - start |
242 | | - if duration > RECOMMENDED_CHECKSUM_DURATION: |
243 | | - logger.warning( |
244 | | - f"Checksum is taking longer than expected ({duration:.2f}s). " |
245 | | - "We recommend increasing --bisection-factor or decreasing --threads." |
246 | | - ) |
247 | | - |
248 | | - if count: |
249 | | - assert checksum, (count, checksum) |
250 | | - return count or 0, checksum if checksum is None else int(checksum) |
251 | | - |
252 | | - def query_key_range(self) -> Tuple[int, int]: |
253 | | - """Query database for minimum and maximum key. This is used for setting the initial bounds.""" |
254 | | - # Normalizes the result (needed for UUIDs) after the min/max computation |
255 | | - select = self._make_select( |
256 | | - columns=[ |
257 | | - self._normalize_column(self.key_column, "min(%s)"), |
258 | | - self._normalize_column(self.key_column, "max(%s)"), |
259 | | - ] |
260 | | - ) |
261 | | - min_key, max_key = self.database.query(select, tuple) |
262 | | - |
263 | | - if min_key is None or max_key is None: |
264 | | - raise ValueError("Table appears to be empty") |
265 | | - |
266 | | - return min_key, max_key |
267 | | - |
268 | | - @property |
269 | | - def is_bounded(self): |
270 | | - return self.min_key is not None and self.max_key is not None |
271 | | - |
272 | | - def approximate_size(self): |
273 | | - if not self.is_bounded: |
274 | | - raise RuntimeError("Cannot approximate the size of an unbounded segment. Must have min_key and max_key.") |
275 | | - return self.max_key - self.min_key |
276 | | - |
277 | | - |
278 | 27 | def diff_sets(a: set, b: set) -> Iterator: |
279 | 28 | s1 = set(a) |
280 | 29 | s2 = set(b) |
@@ -346,6 +95,7 @@ def diff_tables(self, table1: TableSegment, table2: TableSegment) -> DiffResult: |
346 | 95 |
|
347 | 96 | self.stats["diff_count"] = 0 |
348 | 97 | start = time.monotonic() |
| 98 | + error = None |
349 | 99 | try: |
350 | 100 |
|
351 | 101 | # Query and validate schema |
@@ -388,7 +138,6 @@ def diff_tables(self, table1: TableSegment, table2: TableSegment) -> DiffResult: |
388 | 138 | post_tables = [t.new(min_key=max_key1, max_key=max_key2) for t in (table1, table2)] |
389 | 139 | yield from self._bisect_and_diff_tables(*post_tables) |
390 | 140 |
|
391 | | - error = None |
392 | 141 | except BaseException as e: # Catch KeyboardInterrupt too |
393 | 142 | error = e |
394 | 143 | finally: |
@@ -559,7 +308,8 @@ def _threaded_call(self, func, iterable): |
559 | 308 |
|
560 | 309 | def _thread_as_completed(self, func, iterable): |
561 | 310 | if not self.threaded: |
562 | | - return map(func, iterable) |
| 311 | + yield from map(func, iterable) |
| 312 | + return |
563 | 313 |
|
564 | 314 | with ThreadPoolExecutor(max_workers=self.max_threadpool_size) as task_pool: |
565 | 315 | futures = [task_pool.submit(func, item) for item in iterable] |
|
0 commit comments