Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit 8ae2965

Browse files
committed
use proper casing policy for --cloud diffs
1 parent 006af61 commit 8ae2965

File tree

4 files changed

+34
-16
lines changed

4 files changed

+34
-16
lines changed

data_diff/cloud/datafold_api.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,11 @@ def get_data_sources(self) -> List[TCloudApiDataSource]:
198198
rv.raise_for_status()
199199
return [TCloudApiDataSource(**item) for item in rv.json()]
200200

201+
def get_data_source(self, data_source_id: int) -> TCloudApiDataSource:
202+
rv = self.make_get_request(url=f"api/v1/data_sources/{data_source_id}")
203+
rv.raise_for_status()
204+
return TCloudApiDataSource(**rv.json())
205+
201206
def create_data_source(self, config: TDsConfig) -> TCloudApiDataSource:
202207
payload = config.dict()
203208
if config.type == "bigquery":

data_diff/dbt.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,9 @@ def dbt_diff(
8989
"Datasource ID not found, include it as a dbt variable in the dbt_project.yml. "
9090
"\nvars:\n data_diff:\n datasource_id: 1234"
9191
)
92+
93+
data_source = api.get_data_source(datasource_id)
94+
dbt_parser.set_casing_policy_for(connection_type=data_source.type)
9295
rich.print("[green][bold]\nDiffs in progress...[/][/]\n")
9396

9497
else:

data_diff/dbt_parser.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,7 @@ def get_connection_creds(self) -> Tuple[Dict[str, str], str]:
243243

244244
def set_connection(self):
245245
credentials, conn_type = self.get_connection_creds()
246+
self.set_casing_policy_for(conn_type)
246247

247248
if conn_type == "snowflake":
248249
conn_info = {
@@ -257,7 +258,6 @@ def set_connection(self):
257258
"client_session_keep_alive": credentials.get("client_session_keep_alive", False),
258259
}
259260
self.threads = credentials.get("threads")
260-
self.requires_upper = True
261261

262262
if credentials.get("private_key_path") is not None:
263263
if credentials.get("password") is not None:
@@ -410,3 +410,11 @@ def _parse_concat_pk_definition(self, definition: str) -> List[str]:
410410

411411
stripped_columns = [col.strip('" ()') for col in columns]
412412
return stripped_columns
413+
414+
def set_casing_policy_for(self, connection_type: str):
415+
"""
416+
Set casing policy for identifiers: database, schema, table, column, etc.
417+
Correct policy depends on the type of the database, because some databases (e.g. Snowflake)
418+
use upper case identifiers by default, while others (e.g. Postgres) use lower case.
419+
"""
420+
self.requires_upper = connection_type == "snowflake"

tests/test_dbt.py

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import os
22

33
from pathlib import Path
4+
5+
from data_diff.cloud.datafold_api import TCloudApiDataSource
46
from data_diff.diff_tables import Algorithm
57
from .test_cli import run_datadiff_cli
68

@@ -179,12 +181,12 @@ def test_set_connection_snowflake_success_password(self):
179181

180182
DbtParser.set_connection(mock_self)
181183

184+
mock_self.set_casing_policy_for.assert_called_once_with("snowflake")
182185
self.assertIsInstance(mock_self.connection, dict)
183186
self.assertEqual(mock_self.connection.get("driver"), expected_driver)
184187
self.assertEqual(mock_self.connection.get("user"), expected_credentials["user"])
185188
self.assertEqual(mock_self.connection.get("password"), expected_credentials["password"])
186189
self.assertEqual(mock_self.connection.get("key"), None)
187-
self.assertEqual(mock_self.requires_upper, True)
188190

189191
def test_set_connection_snowflake_success_key(self):
190192
expected_driver = "snowflake"
@@ -194,12 +196,12 @@ def test_set_connection_snowflake_success_key(self):
194196

195197
DbtParser.set_connection(mock_self)
196198

199+
mock_self.set_casing_policy_for.assert_called_once_with("snowflake")
197200
self.assertIsInstance(mock_self.connection, dict)
198201
self.assertEqual(mock_self.connection.get("driver"), expected_driver)
199202
self.assertEqual(mock_self.connection.get("user"), expected_credentials["user"])
200203
self.assertEqual(mock_self.connection.get("password"), None)
201204
self.assertEqual(mock_self.connection.get("key"), expected_credentials["private_key_path"])
202-
self.assertEqual(mock_self.requires_upper, True)
203205

204206
def test_set_connection_snowflake_success_key_and_passphrase(self):
205207
expected_driver = "snowflake"
@@ -213,6 +215,7 @@ def test_set_connection_snowflake_success_key_and_passphrase(self):
213215

214216
DbtParser.set_connection(mock_self)
215217

218+
mock_self.set_casing_policy_for.assert_called_once_with("snowflake")
216219
self.assertIsInstance(mock_self.connection, dict)
217220
self.assertEqual(mock_self.connection.get("driver"), expected_driver)
218221
self.assertEqual(mock_self.connection.get("user"), expected_credentials["user"])
@@ -221,7 +224,6 @@ def test_set_connection_snowflake_success_key_and_passphrase(self):
221224
self.assertEqual(
222225
mock_self.connection.get("private_key_passphrase"), expected_credentials["private_key_passphrase"]
223226
)
224-
self.assertEqual(mock_self.requires_upper, True)
225227

226228
def test_set_connection_snowflake_no_key_or_password(self):
227229
expected_driver = "snowflake"
@@ -591,24 +593,22 @@ def test_cloud_diff(self, mock_api, mock_os_environ, mock_print):
591593
@patch("data_diff.dbt._cloud_diff")
592594
@patch("data_diff.dbt_parser.DbtParser.__new__")
593595
@patch("data_diff.dbt.rich.print")
596+
@patch("data_diff.dbt.DatafoldAPI")
594597
def test_diff_is_cloud(
595-
self, mock_print, mock_dbt_parser, mock_cloud_diff, mock_local_diff, mock_get_diff_vars, mock_initialize_api
598+
self, mock_api, mock_print, mock_dbt_parser, mock_cloud_diff, mock_local_diff, mock_get_diff_vars, mock_initialize_api,
596599
):
597600
connection = {}
598601
threads = None
599602
where = "a_string"
600-
host = "a_host"
601-
api_key = "a_api_key"
602603
expected_dbt_vars_dict = {
603604
"prod_database": "prod_db",
604605
"prod_schema": "prod_schema",
605606
"datasource_id": 1,
606607
}
607608
mock_dbt_parser_inst = Mock()
608609
mock_model = Mock()
609-
api = DatafoldAPI(api_key=api_key, host=host)
610-
mock_initialize_api.return_value = api
611-
610+
mock_api.get_data_source.return_value = TCloudApiDataSource(id=1, type="snowflake", name="snowflake")
611+
mock_initialize_api.return_value = mock_api
612612
mock_dbt_parser.return_value = mock_dbt_parser_inst
613613
mock_dbt_parser_inst.get_models.return_value = [mock_model]
614614
mock_dbt_parser_inst.get_datadiff_variables.return_value = expected_dbt_vars_dict
@@ -627,9 +627,11 @@ def test_diff_is_cloud(
627627
dbt_diff(is_cloud=True)
628628
mock_dbt_parser_inst.get_models.assert_called_once()
629629
mock_dbt_parser_inst.set_connection.assert_not_called()
630+
mock_dbt_parser_inst.set_casing_policy_for.assert_called_once()
630631

631632
mock_initialize_api.assert_called_once()
632-
mock_cloud_diff.assert_called_once_with(diff_vars, 1, api)
633+
mock_api.get_data_source.assert_called_once_with(1)
634+
mock_cloud_diff.assert_called_once_with(diff_vars, 1, mock_api)
633635
mock_local_diff.assert_not_called()
634636
mock_print.assert_called_once()
635637

@@ -805,8 +807,9 @@ def test_diff_only_prod_schema(
805807
@patch("data_diff.dbt._cloud_diff")
806808
@patch("data_diff.dbt_parser.DbtParser.__new__")
807809
@patch("data_diff.dbt.rich.print")
810+
@patch("data_diff.dbt.DatafoldAPI")
808811
def test_diff_is_cloud_no_pks(
809-
self, mock_print, mock_dbt_parser, mock_cloud_diff, mock_local_diff, mock_get_diff_vars, mock_initialize_api
812+
self, mock_api, mock_print, mock_dbt_parser, mock_cloud_diff, mock_local_diff, mock_get_diff_vars, mock_initialize_api
810813
):
811814
connection = {}
812815
threads = None
@@ -816,13 +819,11 @@ def test_diff_is_cloud_no_pks(
816819
"prod_schema": "prod_schema",
817820
"datasource_id": 1,
818821
}
819-
host = "a_host"
820-
api_key = "a_api_key"
821822
mock_dbt_parser_inst = Mock()
822823
mock_dbt_parser.return_value = mock_dbt_parser_inst
823824
mock_model = Mock()
824-
api = DatafoldAPI(api_key=api_key, host=host)
825-
mock_initialize_api.return_value = api
825+
mock_initialize_api.return_value = mock_api
826+
mock_api.get_data_source.return_value = TCloudApiDataSource(id=1, type="snowflake", name="snowflake")
826827

827828
mock_dbt_parser_inst.get_models.return_value = [mock_model]
828829
mock_dbt_parser_inst.get_datadiff_variables.return_value = expected_dbt_vars_dict
@@ -840,6 +841,7 @@ def test_diff_is_cloud_no_pks(
840841
dbt_diff(is_cloud=True)
841842

842843
mock_initialize_api.assert_called_once()
844+
mock_api.get_data_source.assert_called_once_with(1)
843845
mock_dbt_parser_inst.get_models.assert_called_once()
844846
mock_dbt_parser_inst.set_connection.assert_not_called()
845847
mock_cloud_diff.assert_not_called()

0 commit comments

Comments
 (0)