Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit f45598c

Browse files
committed
Add tests of datatypes for databricks
1 parent d8be22b commit f45598c

File tree

2 files changed

+17
-8
lines changed

2 files changed

+17
-8
lines changed

tests/common.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
TEST_BIGQUERY_CONN_STRING: str = None
1616
TEST_REDSHIFT_CONN_STRING: str = None
1717
TEST_ORACLE_CONN_STRING: str = None
18-
TEST_DATABRICKS_CONN_STRING: str = None
18+
TEST_DATABRICKS_CONN_STRING: str = os.environ.get("DATADIFF_DATABRICKS_URI")
1919

2020
DEFAULT_N_SAMPLES = 50
2121
N_SAMPLES = int(os.environ.get("N_SAMPLES", DEFAULT_N_SAMPLES))

tests/test_database_types.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -351,18 +351,27 @@ def __iter__(self):
351351
],
352352
},
353353
db.Databricks: {
354+
# https://docs.databricks.com/spark/latest/spark-sql/language-manual/data-types/int-type.html
355+
# https://docs.databricks.com/spark/latest/spark-sql/language-manual/data-types/bigint-type.html
354356
"int": [
355357
"INT",
356358
"BIGINT",
357359
],
360+
361+
# https://docs.databricks.com/spark/latest/spark-sql/language-manual/data-types/timestamp-type.html
358362
"datetime": [
359363
"TIMESTAMP",
360364
],
365+
366+
# https://docs.databricks.com/spark/latest/spark-sql/language-manual/data-types/float-type.html
367+
# https://docs.databricks.com/spark/latest/spark-sql/language-manual/data-types/double-type.html
368+
# https://docs.databricks.com/spark/latest/spark-sql/language-manual/data-types/decimal-type.html
361369
"float": [
362370
"FLOAT",
363371
"DOUBLE",
364372
"DECIMAL(6, 2)",
365373
],
374+
366375
"uuid": [
367376
"STRING",
368377
]
@@ -379,7 +388,7 @@ def __iter__(self):
379388
) in source_type_categories.items(): # int, datetime, ..
380389
for source_type in source_types:
381390
for target_type in target_type_categories[type_category]:
382-
if CONNS.get(source_db, False) and CONNS.get(target_db, False):
391+
if (CONNS.get(source_db, False) and CONNS.get(target_db, False)):
383392
type_pairs.append(
384393
(
385394
source_db,
@@ -476,14 +485,14 @@ def _insert_to_table(conn, table, values, type):
476485
conn.query(insertion_query[0:-1], None)
477486
insertion_query = default_insertion_query
478487

479-
if not isinstance(conn, db.BigQuery):
488+
if not isinstance(conn, (db.BigQuery, db.Databricks)):
480489
conn.query("COMMIT", None)
481490

482491

483492
def _create_indexes(conn, table):
484493
# It is unfortunate that Presto doesn't support creating indexes...
485494
# Technically we could create it in the backing Postgres behind the scenes.
486-
if isinstance(conn, (db.Snowflake, db.Redshift, db.Presto, db.BigQuery)):
495+
if isinstance(conn, (db.Snowflake, db.Redshift, db.Presto, db.BigQuery, db.Databricks)):
487496
return
488497

489498
try:
@@ -516,7 +525,7 @@ def _create_table_with_indexes(conn, table, type):
516525
conn.query(f"CREATE TABLE IF NOT EXISTS {table}(id int, col {type})", None)
517526

518527
_create_indexes(conn, table)
519-
if not isinstance(conn, db.BigQuery):
528+
if not isinstance(conn, (db.BigQuery, db.Databricks)):
520529
conn.query("COMMIT", None)
521530

522531

@@ -527,7 +536,7 @@ def _drop_table_if_exists(conn, table):
527536
conn.query(f"DROP TABLE {table}", None)
528537
else:
529538
conn.query(f"DROP TABLE IF EXISTS {table}", None)
530-
if not isinstance(conn, db.BigQuery):
539+
if not isinstance(conn, (db.BigQuery, db.Databricks)):
531540
conn.query("COMMIT", None)
532541

533542

@@ -563,8 +572,8 @@ def test_types(self, source_db, target_db, source_type, target_type, type_catego
563572

564573
src_table_path = src_conn.parse_table_name(src_table_name)
565574
dst_table_path = dst_conn.parse_table_name(dst_table_name)
566-
self.src_table = src_table = src_conn.quote(".".join(src_table_path))
567-
self.dst_table = dst_table = dst_table = dst_conn.quote(".".join(dst_table_path))
575+
self.src_table = src_table = ".".join(map(src_conn.quote, src_table_path))
576+
self.dst_table = dst_table = ".".join(map(dst_conn.quote, dst_table_path))
568577

569578
start = time.time()
570579
if not BENCHMARK:

0 commit comments

Comments
 (0)