Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit 6fb64bc

Browse files
author
yeonjun kim
authored
Merge branch 'master' into add_mysql_type
2 parents b6387fb + 32d8bf0 commit 6fb64bc

File tree

13 files changed

+60
-97
lines changed

13 files changed

+60
-97
lines changed

data_diff/cloud/datafold_api.py

Lines changed: 4 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -2,19 +2,18 @@
22
import dataclasses
33
import enum
44
import time
5-
from typing import Any, Dict, List, Optional, Type, TypeVar, Tuple
5+
from typing import Any, Dict, List, Optional, Type, Tuple
66

77
import pydantic
88
import requests
9+
from typing_extensions import Self
910

1011
from data_diff.errors import DataDiffCloudDiffFailed, DataDiffCloudDiffTimedOut, DataDiffDatasourceIdNotFoundError
1112

1213
from ..utils import getLogger
1314

1415
logger = getLogger(__name__)
1516

16-
Self = TypeVar("Self", bound=pydantic.BaseModel)
17-
1817

1918
class TestDataSourceStatus(str, enum.Enum):
2019
SUCCESS = "ok"
@@ -30,7 +29,7 @@ class TCloudApiDataSourceSchema(pydantic.BaseModel):
3029
secret: List[str]
3130

3231
@classmethod
33-
def from_orm(cls: Type[Self], obj: Any) -> Self:
32+
def from_orm(cls, obj: Any) -> Self:
3433
data_source_types_required_parameters = {
3534
"bigquery": ["projectId", "jsonKeyFile", "location"],
3635
"databricks": ["host", "http_password", "database", "http_path"],
@@ -154,7 +153,7 @@ class TCloudApiDataDiffSummaryResult(pydantic.BaseModel):
154153
dependencies: Optional[Dict[str, Any]]
155154

156155
@classmethod
157-
def from_orm(cls: Type[Self], obj: Any) -> Self:
156+
def from_orm(cls, obj: Any) -> Self:
158157
pks = TSummaryResultPrimaryKeyStats(**obj["pks"]) if "pks" in obj else None
159158
values = TSummaryResultValueStats(**obj["values"]) if "values" in obj else None
160159
deps = obj["deps"] if "deps" in obj else None

data_diff/sqeleton/abcs/database_types.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,11 +1,12 @@
11
import decimal
22
from abc import ABC, abstractmethod
3-
from typing import Sequence, Optional, Tuple, Union, Dict, List
3+
from typing import Sequence, Optional, Tuple, Type, Union, Dict, List
44
from datetime import datetime
55

66
from runtype import dataclass
7+
from typing_extensions import Self
78

8-
from ..utils import ArithAlphanumeric, ArithUUID, Self, Unknown
9+
from ..utils import ArithAlphanumeric, ArithUUID, Unknown
910

1011

1112
DbPath = Tuple[str, ...]

data_diff/sqeleton/bound_exprs.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -5,6 +5,7 @@
55
from typing import Union, TYPE_CHECKING
66

77
from runtype import dataclass
8+
from typing_extensions import Self
89

910
from .abcs import AbstractDatabase, AbstractCompiler
1011
from .queries.ast_classes import ExprNode, ITable, TablePath, Compilable
@@ -52,11 +53,11 @@ class BoundTable(BoundNode): # ITable
5253
database: AbstractDatabase
5354
node: TablePath
5455

55-
def with_schema(self, schema):
56+
def with_schema(self, schema) -> Self:
5657
table_path = self.node.replace(schema=schema)
5758
return self.replace(node=table_path)
5859

59-
def query_schema(self, *, columns=None, where=None, case_sensitive=True):
60+
def query_schema(self, *, columns=None, where=None, case_sensitive=True) -> Self:
6061
table_path = self.node
6162

6263
if table_path.schema:

data_diff/sqeleton/databases/_connect.py

Lines changed: 14 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -1,13 +1,14 @@
1-
from typing import Type, Optional, Union, Dict
1+
from typing import Hashable, MutableMapping, Type, Optional, Union, Dict
22
from itertools import zip_longest
33
from contextlib import suppress
4+
import weakref
45
import dsnparse
56
import toml
67

78
from runtype import dataclass
9+
from typing_extensions import Self
810

911
from ..abcs.mixins import AbstractMixin
10-
from ..utils import WeakCache, Self
1112
from .base import Database, ThreadedDatabase
1213
from .postgresql import PostgreSQL
1314
from .mysql import MySQL
@@ -93,13 +94,14 @@ def match_path(self, dsn):
9394

9495
class Connect:
9596
"""Provides methods for connecting to a supported database using a URL or connection dict."""
97+
conn_cache: MutableMapping[Hashable, Database]
9698

9799
def __init__(self, database_by_scheme: Dict[str, Database] = DATABASE_BY_SCHEME):
98100
self.database_by_scheme = database_by_scheme
99101
self.match_uri_path = {name: MatchUriPath(cls) for name, cls in database_by_scheme.items()}
100-
self.conn_cache = WeakCache()
102+
self.conn_cache = weakref.WeakValueDictionary()
101103

102-
def for_databases(self, *dbs):
104+
def for_databases(self, *dbs) -> Self:
103105
database_by_scheme = {k: db for k, db in self.database_by_scheme.items() if k in dbs}
104106
return type(self)(database_by_scheme)
105107

@@ -262,9 +264,10 @@ def __call__(
262264
>>> connect({"driver": "mysql", "host": "localhost", "database": "db"})
263265
<data_diff.sqeleton.databases.mysql.MySQL object at ...>
264266
"""
267+
cache_key = self.__make_cache_key(db_conf)
265268
if shared:
266269
with suppress(KeyError):
267-
conn = self.conn_cache.get(db_conf)
270+
conn = self.conn_cache[cache_key]
268271
if not conn.is_closed:
269272
return conn
270273

@@ -276,5 +279,10 @@ def __call__(
276279
raise TypeError(f"db configuration must be a URI string or a dictionary. Instead got '{db_conf}'.")
277280

278281
if shared:
279-
self.conn_cache.add(db_conf, conn)
282+
self.conn_cache[cache_key] = conn
280283
return conn
284+
285+
def __make_cache_key(self, db_conf: Union[str, dict]) -> Hashable:
286+
if isinstance(db_conf, dict):
287+
return tuple(db_conf.items())
288+
return db_conf

data_diff/sqeleton/databases/base.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -11,8 +11,9 @@
1111
import decimal
1212

1313
from runtype import dataclass
14+
from typing_extensions import Self
1415

15-
from ..utils import is_uuid, safezip, Self
16+
from ..utils import is_uuid, safezip
1617
from ..queries import Expr, Compiler, table, Select, SKIP, Explain, Code, this
1718
from ..queries.ast_classes import Random
1819
from ..abcs.database_types import (
@@ -281,7 +282,7 @@ def _convert_db_precision_to_digits(self, p: int) -> int:
281282
return math.floor(math.log(2**p, 10))
282283

283284
@classmethod
284-
def load_mixins(cls, *abstract_mixins) -> "Self":
285+
def load_mixins(cls, *abstract_mixins) -> Self:
285286
mixins = {m for m in cls.MIXINS if issubclass(m, abstract_mixins)}
286287

287288
class _DialectWithMixins(cls, *mixins, *abstract_mixins):

data_diff/sqeleton/queries/ast_classes.py

Lines changed: 10 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -1,8 +1,9 @@
11
from dataclasses import field
22
from datetime import datetime
3-
from typing import Any, Generator, List, Optional, Sequence, Union, Dict
3+
from typing import Any, Generator, List, Optional, Sequence, Type, Union, Dict
44

55
from runtype import dataclass
6+
from typing_extensions import Self
67

78
from ..utils import join_iter, ArithString
89
from ..abcs import Compilable
@@ -322,7 +323,7 @@ def when(self, *whens: Expr) -> "QB_When":
322323
return QB_When(self, whens[0])
323324
return QB_When(self, BinBoolOp("AND", whens))
324325

325-
def else_(self, then: Expr):
326+
def else_(self, then: Expr) -> Self:
326327
"""Add an 'else' clause to the case expression.
327328
328329
Can only be called once!
@@ -422,7 +423,7 @@ class TablePath(ExprNode, ITable):
422423
schema: Optional[Schema] = field(default=None, repr=False)
423424

424425
@property
425-
def source_table(self):
426+
def source_table(self) -> Self:
426427
return self
427428

428429
def compile(self, c: Compiler) -> str:
@@ -524,7 +525,7 @@ class Join(ExprNode, ITable, Root):
524525
columns: Sequence[Expr] = None
525526

526527
@property
527-
def source_table(self):
528+
def source_table(self) -> Self:
528529
return self
529530

530531
@property
@@ -533,7 +534,7 @@ def schema(self):
533534
s = self.source_tables[0].schema # TODO validate types match between both tables
534535
return type(s)({c.name: c.type for c in self.columns})
535536

536-
def on(self, *exprs) -> "Join":
537+
def on(self, *exprs) -> Self:
537538
"""Add an ON clause, for filtering the result of the cartesian product (i.e. the JOIN)"""
538539
if len(exprs) == 1:
539540
(e,) = exprs
@@ -546,7 +547,7 @@ def on(self, *exprs) -> "Join":
546547

547548
return self.replace(on_exprs=(self.on_exprs or []) + exprs)
548549

549-
def select(self, *exprs, **named_exprs) -> ITable:
550+
def select(self, *exprs, **named_exprs) -> Union[Self, ITable]:
550551
"""Select fields to return from the JOIN operation
551552
552553
See Also: ``ITable.select()``
@@ -600,7 +601,7 @@ def source_table(self):
600601
def __post_init__(self):
601602
assert self.keys or self.values
602603

603-
def having(self, *exprs):
604+
def having(self, *exprs) -> Self:
604605
"""Add a 'HAVING' clause to the group-by"""
605606
exprs = args_as_tuple(exprs)
606607
exprs = _drop_skips(exprs)
@@ -610,7 +611,7 @@ def having(self, *exprs):
610611
resolve_names(self.table, exprs)
611612
return self.replace(having_exprs=(self.having_exprs or []) + exprs)
612613

613-
def agg(self, *exprs):
614+
def agg(self, *exprs) -> Self:
614615
"""Select aggregated fields for the group-by."""
615616
exprs = args_as_tuple(exprs)
616617
exprs = _drop_skips(exprs)
@@ -991,7 +992,7 @@ def compile(self, c: Compiler) -> str:
991992

992993
return f"INSERT INTO {c.compile(self.path)}{columns} {expr}"
993994

994-
def returning(self, *exprs):
995+
def returning(self, *exprs) -> Self:
995996
"""Add a 'RETURNING' clause to the insert expression.
996997
997998
Note: Not all databases support this feature!

data_diff/sqeleton/queries/compiler.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,7 @@
44
from typing import Any, Dict, Sequence, List
55

66
from runtype import dataclass
7+
from typing_extensions import Self
78

89
from ..utils import ArithString
910
from ..abcs import AbstractDatabase, AbstractDialect, DbPath, AbstractCompiler, Compilable
@@ -79,7 +80,7 @@ def new_unique_table_name(self, prefix="tmp") -> DbPath:
7980
self._counter[0] += 1
8081
return self.database.parse_table_name(f"{prefix}{self._counter[0]}_{'%x'%random.randrange(2**32)}")
8182

82-
def add_table_context(self, *tables: Sequence, **kw):
83+
def add_table_context(self, *tables: Sequence, **kw) -> Self:
8384
return self.replace(_table_context=self._table_context + list(tables), **kw)
8485

8586
def quote(self, s: str):

data_diff/sqeleton/utils.py

Lines changed: 10 additions & 39 deletions
Original file line number | Diff line number | Diff line change
@@ -2,53 +2,24 @@
22
Iterable,
33
Iterator,
44
MutableMapping,
5+
Type,
56
Union,
67
Any,
78
Sequence,
89
Dict,
9-
Hashable,
1010
TypeVar,
11-
TYPE_CHECKING,
1211
List,
1312
)
1413
from abc import abstractmethod
15-
from weakref import ref
1614
import math
1715
import string
1816
import re
1917
from uuid import UUID
2018
from urllib.parse import urlparse
2119

22-
# -- Common --
23-
24-
try:
25-
from typing import Self
26-
except ImportError:
27-
Self = Any
28-
29-
30-
class WeakCache:
31-
def __init__(self):
32-
self._cache = {}
33-
34-
def _hashable_key(self, k: Union[dict, Hashable]) -> Hashable:
35-
if isinstance(k, dict):
36-
return tuple(k.items())
37-
return k
20+
from typing_extensions import Self
3821

39-
def add(self, key: Union[dict, Hashable], value: Any):
40-
key = self._hashable_key(key)
41-
self._cache[key] = ref(value)
42-
43-
def get(self, key: Union[dict, Hashable]) -> Any:
44-
key = self._hashable_key(key)
45-
46-
value = self._cache[key]()
47-
if value is None:
48-
del self._cache[key]
49-
raise KeyError(f"Key {key} not found, or no longer a valid reference")
50-
51-
return value
22+
# -- Common --
5223

5324

5425
def join_iter(joiner: Any, iterable: Iterable) -> Iterable:
@@ -95,7 +66,7 @@ class CaseAwareMapping(MutableMapping[str, V]):
9566
def get_key(self, key: str) -> str:
9667
...
9768

98-
def new(self, initial=()):
69+
def new(self, initial=()) -> Self:
9970
return type(self)(initial)
10071

10172

@@ -144,10 +115,10 @@ def as_insensitive(self):
144115

145116
class ArithString:
146117
@classmethod
147-
def new(cls, *args, **kw):
118+
def new(cls, *args, **kw) -> Self:
148119
return cls(*args, **kw)
149120

150-
def range(self, other: "ArithString", count: int):
121+
def range(self, other: "ArithString", count: int) -> List[Self]:
151122
assert isinstance(other, ArithString)
152123
checkpoints = split_space(self.int, other.int, count)
153124
return [self.new(int=i) for i in checkpoints]
@@ -159,7 +130,7 @@ class ArithUUID(UUID, ArithString):
159130
def __int__(self):
160131
return self.int
161132

162-
def __add__(self, other: int):
133+
def __add__(self, other: int) -> Self:
163134
if isinstance(other, int):
164135
return self.new(int=self.int + other)
165136
return NotImplemented
@@ -231,7 +202,7 @@ def __len__(self):
231202
def __repr__(self):
232203
return f'alphanum"{self._str}"'
233204

234-
def __add__(self, other: "Union[ArithAlphanumeric, int]") -> "ArithAlphanumeric":
205+
def __add__(self, other: "Union[ArithAlphanumeric, int]") -> Self:
235206
if isinstance(other, int):
236207
if other != 1:
237208
raise NotImplementedError("not implemented for arbitrary numbers")
@@ -240,7 +211,7 @@ def __add__(self, other: "Union[ArithAlphanumeric, int]") -> "ArithAlphanumeric"
240211

241212
return NotImplemented
242213

243-
def range(self, other: "ArithAlphanumeric", count: int):
214+
def range(self, other: "ArithAlphanumeric", count: int) -> List[Self]:
244215
assert isinstance(other, ArithAlphanumeric)
245216
n1, n2 = alphanums_to_numbers(self._str, other._str)
246217
split = split_space(n1, n2, count)
@@ -268,7 +239,7 @@ def __eq__(self, other):
268239
return NotImplemented
269240
return self._str == other._str
270241

271-
def new(self, *args, **kw):
242+
def new(self, *args, **kw) -> Self:
272243
return type(self)(*args, **kw, max_len=self._max_len)
273244

274245

0 commit comments

Comments
 (0)