|
| 1 | +from typing import Optional, Type |
| 2 | + |
| 3 | +from .base import ( |
| 4 | + MD5_HEXDIGITS, |
| 5 | + CHECKSUM_HEXDIGITS, |
| 6 | + TIMESTAMP_PRECISION_POS, |
| 7 | + ThreadedDatabase, |
| 8 | + import_helper, |
| 9 | + ConnectError, |
| 10 | +) |
| 11 | +from .database_types import ColType, Decimal, Float, Integer, FractionalType, Native_UUID, TemporalType, Text, Timestamp |
| 12 | + |
| 13 | + |
| 14 | +@import_helper("clickhouse") |
| 15 | +def import_clickhouse(): |
| 16 | + import clickhouse_driver |
| 17 | + |
| 18 | + return clickhouse_driver |
| 19 | + |
| 20 | + |
| 21 | +class Clickhouse(ThreadedDatabase): |
| 22 | + TYPE_CLASSES = { |
| 23 | + "Int8": Integer, |
| 24 | + "Int16": Integer, |
| 25 | + "Int32": Integer, |
| 26 | + "Int64": Integer, |
| 27 | + "Int128": Integer, |
| 28 | + "Int256": Integer, |
| 29 | + "UInt8": Integer, |
| 30 | + "UInt16": Integer, |
| 31 | + "UInt32": Integer, |
| 32 | + "UInt64": Integer, |
| 33 | + "UInt128": Integer, |
| 34 | + "UInt256": Integer, |
| 35 | + "Float32": Float, |
| 36 | + "Float64": Float, |
| 37 | + "Decimal": Decimal, |
| 38 | + "UUID": Native_UUID, |
| 39 | + "String": Text, |
| 40 | + "FixedString": Text, |
| 41 | + "DateTime": Timestamp, |
| 42 | + "DateTime64": Timestamp, |
| 43 | + } |
| 44 | + ROUNDS_ON_PREC_LOSS = False |
| 45 | + |
| 46 | + def __init__(self, *, thread_count: int, **kw): |
| 47 | + super().__init__(thread_count=thread_count) |
| 48 | + |
| 49 | + self._args = kw |
| 50 | + # In Clickhouse database and schema are the same |
| 51 | + self.default_schema = kw["database"] |
| 52 | + |
| 53 | + def create_connection(self): |
| 54 | + clickhouse = import_clickhouse() |
| 55 | + |
| 56 | + class SingleConnection(clickhouse.dbapi.connection.Connection): |
| 57 | + """Not thread-safe connection to Clickhouse""" |
| 58 | + |
| 59 | + def cursor(self, cursor_factory=None): |
| 60 | + if not len(self.cursors): |
| 61 | + _ = super().cursor() |
| 62 | + return self.cursors[0] |
| 63 | + |
| 64 | + try: |
| 65 | + return SingleConnection(**self._args) |
| 66 | + except clickhouse.OperationError as e: |
| 67 | + raise ConnectError(*e.args) from e |
| 68 | + |
| 69 | + def _parse_type_repr(self, type_repr: str) -> Optional[Type[ColType]]: |
| 70 | + nullable_prefix = "Nullable(" |
| 71 | + if type_repr.startswith(nullable_prefix): |
| 72 | + type_repr = type_repr[len(nullable_prefix):].rstrip(")") |
| 73 | + |
| 74 | + if type_repr.startswith("Decimal"): |
| 75 | + type_repr = "Decimal" |
| 76 | + elif type_repr.startswith("FixedString"): |
| 77 | + type_repr = "FixedString" |
| 78 | + elif type_repr.startswith("DateTime64"): |
| 79 | + type_repr = "DateTime64" |
| 80 | + |
| 81 | + return self.TYPE_CLASSES.get(type_repr) |
| 82 | + |
| 83 | + def quote(self, s: str) -> str: |
| 84 | + return f'"{s}"' |
| 85 | + |
| 86 | + def md5_to_int(self, s: str) -> str: |
| 87 | + substr_idx = 1 + MD5_HEXDIGITS - CHECKSUM_HEXDIGITS |
| 88 | + return f"reinterpretAsUInt128(reverse(unhex(lowerUTF8(substr(hex(MD5({s})), {substr_idx})))))" |
| 89 | + |
| 90 | + def to_string(self, s: str) -> str: |
| 91 | + return f"toString({s})" |
| 92 | + |
| 93 | + def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: |
| 94 | + prec= coltype.precision |
| 95 | + if coltype.rounds: |
| 96 | + timestamp = f"toDateTime64(round(toUnixTimestamp64Micro(toDateTime64({value}, 6)) / 1000000, {prec}), 6)" |
| 97 | + return self.to_string(timestamp) |
| 98 | + |
| 99 | + fractional = f"toUnixTimestamp64Micro(toDateTime64({value}, {prec})) % 1000000" |
| 100 | + fractional = f"lpad({self.to_string(fractional)}, 6, '0')" |
| 101 | + value = f"formatDateTime({value}, '%Y-%m-%d %H:%M:%S') || '.' || {self.to_string(fractional)}" |
| 102 | + return f"rpad({value}, {TIMESTAMP_PRECISION_POS + 6}, '0')" |
| 103 | + |
| 104 | + def _convert_db_precision_to_digits(self, p: int) -> int: |
| 105 | + # Done the same as for PostgreSQL but need to rewrite in another way |
| 106 | + # because it does not help for float with a big integer part. |
| 107 | + return super()._convert_db_precision_to_digits(p) - 2 |
| 108 | + |
| 109 | + def normalize_number(self, value: str, coltype: FractionalType) -> str: |
| 110 | + # If a decimal value has trailing zeros in a fractional part, when casting to string they are dropped. |
| 111 | + # For example: |
| 112 | + # select toString(toDecimal128(1.10, 2)); -- the result is 1.1 |
| 113 | + # select toString(toDecimal128(1.00, 2)); -- the result is 1 |
| 114 | + # So, we should use some custom approach to save these trailing zeros. |
| 115 | + # To avoid it, we can add a small value like 0.000001 to prevent dropping of zeros from the end when casting. |
| 116 | + # For examples above it looks like: |
| 117 | + # select toString(toDecimal128(1.10, 2 + 1) + toDecimal128(0.001, 3)); -- the result is 1.101 |
| 118 | + # After that, cut an extra symbol from the string, i.e. 1.101 -> 1.10 |
| 119 | + # So, the algorithm is: |
| 120 | + # 1. Cast to decimal with precision + 1 |
| 121 | + # 2. Add a small value 10^(-precision-1) |
| 122 | + # 3. Cast the result to string |
| 123 | + # 4. Drop the extra digit from the string. To do that, we need to slice the string |
| 124 | + # with length = digits in an integer part + 1 (symbol of ".") + precision |
| 125 | + |
| 126 | + if coltype.precision == 0: |
| 127 | + return self.to_string(f"round({value})") |
| 128 | + |
| 129 | + precision = coltype.precision |
| 130 | + # TODO: too complex, is there better performance way? |
| 131 | + value = f""" |
| 132 | + if({value} >= 0, '', '-') || left( |
| 133 | + toString( |
| 134 | + toDecimal128( |
| 135 | + round(abs({value}), {precision}), |
| 136 | + {precision} + 1 |
| 137 | + ) |
| 138 | + + |
| 139 | + toDecimal128( |
| 140 | + exp10(-{precision + 1}), |
| 141 | + {precision} + 1 |
| 142 | + ) |
| 143 | + ), |
| 144 | + toUInt8( |
| 145 | + greatest( |
| 146 | + floor(log10(abs({value}))) + 1, |
| 147 | + 1 |
| 148 | + ) |
| 149 | + ) + 1 + {precision} |
| 150 | + ) |
| 151 | + """ |
| 152 | + return value |
0 commit comments