Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit e0060e1

Browse files
authored
Merge pull request #579 from datafold/casing-policy-for-cloud-diffs
Use proper casing policy for --cloud diffs
2 parents 4f4097d + 8ae2965 commit e0060e1

File tree

4 files changed

+34
-16
lines changed

4 files changed

+34
-16
lines changed

data_diff/cloud/datafold_api.py

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -198,6 +198,11 @@ def get_data_sources(self) -> List[TCloudApiDataSource]:
198198
rv.raise_for_status()
199199
return [TCloudApiDataSource(**item) for item in rv.json()]
200200

201+
def get_data_source(self, data_source_id: int) -> TCloudApiDataSource:
202+
rv = self.make_get_request(url=f"api/v1/data_sources/{data_source_id}")
203+
rv.raise_for_status()
204+
return TCloudApiDataSource(**rv.json())
205+
201206
def create_data_source(self, config: TDsConfig) -> TCloudApiDataSource:
202207
payload = config.dict()
203208
if config.type == "bigquery":

data_diff/dbt.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -89,6 +89,9 @@ def dbt_diff(
8989
"Datasource ID not found, include it as a dbt variable in the dbt_project.yml. "
9090
"\nvars:\n data_diff:\n datasource_id: 1234"
9191
)
92+
93+
data_source = api.get_data_source(datasource_id)
94+
dbt_parser.set_casing_policy_for(connection_type=data_source.type)
9295
rich.print("[green][bold]\nDiffs in progress...[/][/]\n")
9396

9497
else:

data_diff/dbt_parser.py

Lines changed: 9 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -238,6 +238,7 @@ def get_connection_creds(self) -> Tuple[Dict[str, str], str]:
238238

239239
def set_connection(self):
240240
credentials, conn_type = self.get_connection_creds()
241+
self.set_casing_policy_for(conn_type)
241242

242243
if conn_type == "snowflake":
243244
conn_info = {
@@ -252,7 +253,6 @@ def set_connection(self):
252253
"client_session_keep_alive": credentials.get("client_session_keep_alive", False),
253254
}
254255
self.threads = credentials.get("threads")
255-
self.requires_upper = True
256256

257257
if credentials.get("private_key_path") is not None:
258258
if credentials.get("password") is not None:
@@ -405,3 +405,11 @@ def _parse_concat_pk_definition(self, definition: str) -> List[str]:
405405

406406
stripped_columns = [col.strip('" ()') for col in columns]
407407
return stripped_columns
408+
409+
def set_casing_policy_for(self, connection_type: str):
410+
"""
411+
Set casing policy for identifiers: database, schema, table, column, etc.
412+
Correct policy depends on the type of the database, because some databases (e.g. Snowflake)
413+
use upper case identifiers by default, while others (e.g. Postgres) use lower case.
414+
"""
415+
self.requires_upper = connection_type == "snowflake"

tests/test_dbt.py

Lines changed: 17 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,8 @@
11
import os
22

33
from pathlib import Path
4+
5+
from data_diff.cloud.datafold_api import TCloudApiDataSource
46
from data_diff.diff_tables import Algorithm
57
from .test_cli import run_datadiff_cli
68

@@ -183,12 +185,12 @@ def test_set_connection_snowflake_success_password(self):
183185

184186
DbtParser.set_connection(mock_self)
185187

188+
mock_self.set_casing_policy_for.assert_called_once_with("snowflake")
186189
self.assertIsInstance(mock_self.connection, dict)
187190
self.assertEqual(mock_self.connection.get("driver"), expected_driver)
188191
self.assertEqual(mock_self.connection.get("user"), expected_credentials["user"])
189192
self.assertEqual(mock_self.connection.get("password"), expected_credentials["password"])
190193
self.assertEqual(mock_self.connection.get("key"), None)
191-
self.assertEqual(mock_self.requires_upper, True)
192194

193195
def test_set_connection_snowflake_success_key(self):
194196
expected_driver = "snowflake"
@@ -198,12 +200,12 @@ def test_set_connection_snowflake_success_key(self):
198200

199201
DbtParser.set_connection(mock_self)
200202

203+
mock_self.set_casing_policy_for.assert_called_once_with("snowflake")
201204
self.assertIsInstance(mock_self.connection, dict)
202205
self.assertEqual(mock_self.connection.get("driver"), expected_driver)
203206
self.assertEqual(mock_self.connection.get("user"), expected_credentials["user"])
204207
self.assertEqual(mock_self.connection.get("password"), None)
205208
self.assertEqual(mock_self.connection.get("key"), expected_credentials["private_key_path"])
206-
self.assertEqual(mock_self.requires_upper, True)
207209

208210
def test_set_connection_snowflake_success_key_and_passphrase(self):
209211
expected_driver = "snowflake"
@@ -217,6 +219,7 @@ def test_set_connection_snowflake_success_key_and_passphrase(self):
217219

218220
DbtParser.set_connection(mock_self)
219221

222+
mock_self.set_casing_policy_for.assert_called_once_with("snowflake")
220223
self.assertIsInstance(mock_self.connection, dict)
221224
self.assertEqual(mock_self.connection.get("driver"), expected_driver)
222225
self.assertEqual(mock_self.connection.get("user"), expected_credentials["user"])
@@ -225,7 +228,6 @@ def test_set_connection_snowflake_success_key_and_passphrase(self):
225228
self.assertEqual(
226229
mock_self.connection.get("private_key_passphrase"), expected_credentials["private_key_passphrase"]
227230
)
228-
self.assertEqual(mock_self.requires_upper, True)
229231

230232
def test_set_connection_snowflake_no_key_or_password(self):
231233
expected_driver = "snowflake"
@@ -609,24 +611,22 @@ def test_cloud_diff(self, mock_api, mock_os_environ, mock_print):
609611
@patch("data_diff.dbt._cloud_diff")
610612
@patch("data_diff.dbt_parser.DbtParser.__new__")
611613
@patch("data_diff.dbt.rich.print")
614+
@patch("data_diff.dbt.DatafoldAPI")
612615
def test_diff_is_cloud(
613-
self, mock_print, mock_dbt_parser, mock_cloud_diff, mock_local_diff, mock_get_diff_vars, mock_initialize_api
616+
self, mock_api, mock_print, mock_dbt_parser, mock_cloud_diff, mock_local_diff, mock_get_diff_vars, mock_initialize_api,
614617
):
615618
connection = {}
616619
threads = None
617620
where = "a_string"
618-
host = "a_host"
619-
api_key = "a_api_key"
620621
expected_dbt_vars_dict = {
621622
"prod_database": "prod_db",
622623
"prod_schema": "prod_schema",
623624
"datasource_id": 1,
624625
}
625626
mock_dbt_parser_inst = Mock()
626627
mock_model = Mock()
627-
api = DatafoldAPI(api_key=api_key, host=host)
628-
mock_initialize_api.return_value = api
629-
628+
mock_api.get_data_source.return_value = TCloudApiDataSource(id=1, type="snowflake", name="snowflake")
629+
mock_initialize_api.return_value = mock_api
630630
mock_dbt_parser.return_value = mock_dbt_parser_inst
631631
mock_dbt_parser_inst.get_models.return_value = [mock_model]
632632
mock_dbt_parser_inst.get_datadiff_variables.return_value = expected_dbt_vars_dict
@@ -645,9 +645,11 @@ def test_diff_is_cloud(
645645
dbt_diff(is_cloud=True)
646646
mock_dbt_parser_inst.get_models.assert_called_once()
647647
mock_dbt_parser_inst.set_connection.assert_not_called()
648+
mock_dbt_parser_inst.set_casing_policy_for.assert_called_once()
648649

649650
mock_initialize_api.assert_called_once()
650-
mock_cloud_diff.assert_called_once_with(diff_vars, 1, api)
651+
mock_api.get_data_source.assert_called_once_with(1)
652+
mock_cloud_diff.assert_called_once_with(diff_vars, 1, mock_api)
651653
mock_local_diff.assert_not_called()
652654
mock_print.assert_called_once()
653655

@@ -823,8 +825,9 @@ def test_diff_only_prod_schema(
823825
@patch("data_diff.dbt._cloud_diff")
824826
@patch("data_diff.dbt_parser.DbtParser.__new__")
825827
@patch("data_diff.dbt.rich.print")
828+
@patch("data_diff.dbt.DatafoldAPI")
826829
def test_diff_is_cloud_no_pks(
827-
self, mock_print, mock_dbt_parser, mock_cloud_diff, mock_local_diff, mock_get_diff_vars, mock_initialize_api
830+
self, mock_api, mock_print, mock_dbt_parser, mock_cloud_diff, mock_local_diff, mock_get_diff_vars, mock_initialize_api
828831
):
829832
connection = {}
830833
threads = None
@@ -834,13 +837,11 @@ def test_diff_is_cloud_no_pks(
834837
"prod_schema": "prod_schema",
835838
"datasource_id": 1,
836839
}
837-
host = "a_host"
838-
api_key = "a_api_key"
839840
mock_dbt_parser_inst = Mock()
840841
mock_dbt_parser.return_value = mock_dbt_parser_inst
841842
mock_model = Mock()
842-
api = DatafoldAPI(api_key=api_key, host=host)
843-
mock_initialize_api.return_value = api
843+
mock_initialize_api.return_value = mock_api
844+
mock_api.get_data_source.return_value = TCloudApiDataSource(id=1, type="snowflake", name="snowflake")
844845

845846
mock_dbt_parser_inst.get_models.return_value = [mock_model]
846847
mock_dbt_parser_inst.get_datadiff_variables.return_value = expected_dbt_vars_dict
@@ -858,6 +859,7 @@ def test_diff_is_cloud_no_pks(
858859
dbt_diff(is_cloud=True)
859860

860861
mock_initialize_api.assert_called_once()
862+
mock_api.get_data_source.assert_called_once_with(1)
861863
mock_dbt_parser_inst.get_models.assert_called_once()
862864
mock_dbt_parser_inst.set_connection.assert_not_called()
863865
mock_cloud_diff.assert_not_called()

0 commit comments

Comments
 (0)