11import os
22
33from pathlib import Path
4+
5+ from data_diff .cloud .datafold_api import TCloudApiDataSource
46from data_diff .diff_tables import Algorithm
57from .test_cli import run_datadiff_cli
68
@@ -179,12 +181,12 @@ def test_set_connection_snowflake_success_password(self):
179181
180182 DbtParser .set_connection (mock_self )
181183
184+ mock_self .set_casing_policy_for .assert_called_once_with ("snowflake" )
182185 self .assertIsInstance (mock_self .connection , dict )
183186 self .assertEqual (mock_self .connection .get ("driver" ), expected_driver )
184187 self .assertEqual (mock_self .connection .get ("user" ), expected_credentials ["user" ])
185188 self .assertEqual (mock_self .connection .get ("password" ), expected_credentials ["password" ])
186189 self .assertEqual (mock_self .connection .get ("key" ), None )
187- self .assertEqual (mock_self .requires_upper , True )
188190
189191 def test_set_connection_snowflake_success_key (self ):
190192 expected_driver = "snowflake"
@@ -194,12 +196,12 @@ def test_set_connection_snowflake_success_key(self):
194196
195197 DbtParser .set_connection (mock_self )
196198
199+ mock_self .set_casing_policy_for .assert_called_once_with ("snowflake" )
197200 self .assertIsInstance (mock_self .connection , dict )
198201 self .assertEqual (mock_self .connection .get ("driver" ), expected_driver )
199202 self .assertEqual (mock_self .connection .get ("user" ), expected_credentials ["user" ])
200203 self .assertEqual (mock_self .connection .get ("password" ), None )
201204 self .assertEqual (mock_self .connection .get ("key" ), expected_credentials ["private_key_path" ])
202- self .assertEqual (mock_self .requires_upper , True )
203205
204206 def test_set_connection_snowflake_success_key_and_passphrase (self ):
205207 expected_driver = "snowflake"
@@ -213,6 +215,7 @@ def test_set_connection_snowflake_success_key_and_passphrase(self):
213215
214216 DbtParser .set_connection (mock_self )
215217
218+ mock_self .set_casing_policy_for .assert_called_once_with ("snowflake" )
216219 self .assertIsInstance (mock_self .connection , dict )
217220 self .assertEqual (mock_self .connection .get ("driver" ), expected_driver )
218221 self .assertEqual (mock_self .connection .get ("user" ), expected_credentials ["user" ])
@@ -221,7 +224,6 @@ def test_set_connection_snowflake_success_key_and_passphrase(self):
221224 self .assertEqual (
222225 mock_self .connection .get ("private_key_passphrase" ), expected_credentials ["private_key_passphrase" ]
223226 )
224- self .assertEqual (mock_self .requires_upper , True )
225227
226228 def test_set_connection_snowflake_no_key_or_password (self ):
227229 expected_driver = "snowflake"
@@ -591,24 +593,22 @@ def test_cloud_diff(self, mock_api, mock_os_environ, mock_print):
591593 @patch ("data_diff.dbt._cloud_diff" )
592594 @patch ("data_diff.dbt_parser.DbtParser.__new__" )
593595 @patch ("data_diff.dbt.rich.print" )
596+ @patch ("data_diff.dbt.DatafoldAPI" )
594597 def test_diff_is_cloud (
595- self , mock_print , mock_dbt_parser , mock_cloud_diff , mock_local_diff , mock_get_diff_vars , mock_initialize_api
598+ self , mock_api , mock_print , mock_dbt_parser , mock_cloud_diff , mock_local_diff , mock_get_diff_vars , mock_initialize_api ,
596599 ):
597600 connection = {}
598601 threads = None
599602 where = "a_string"
600- host = "a_host"
601- api_key = "a_api_key"
602603 expected_dbt_vars_dict = {
603604 "prod_database" : "prod_db" ,
604605 "prod_schema" : "prod_schema" ,
605606 "datasource_id" : 1 ,
606607 }
607608 mock_dbt_parser_inst = Mock ()
608609 mock_model = Mock ()
609- api = DatafoldAPI (api_key = api_key , host = host )
610- mock_initialize_api .return_value = api
611-
610+ mock_api .get_data_source .return_value = TCloudApiDataSource (id = 1 , type = "snowflake" , name = "snowflake" )
611+ mock_initialize_api .return_value = mock_api
612612 mock_dbt_parser .return_value = mock_dbt_parser_inst
613613 mock_dbt_parser_inst .get_models .return_value = [mock_model ]
614614 mock_dbt_parser_inst .get_datadiff_variables .return_value = expected_dbt_vars_dict
@@ -627,9 +627,11 @@ def test_diff_is_cloud(
627627 dbt_diff (is_cloud = True )
628628 mock_dbt_parser_inst .get_models .assert_called_once ()
629629 mock_dbt_parser_inst .set_connection .assert_not_called ()
630+ mock_dbt_parser_inst .set_casing_policy_for .assert_called_once ()
630631
631632 mock_initialize_api .assert_called_once ()
632- mock_cloud_diff .assert_called_once_with (diff_vars , 1 , api )
633+ mock_api .get_data_source .assert_called_once_with (1 )
634+ mock_cloud_diff .assert_called_once_with (diff_vars , 1 , mock_api )
633635 mock_local_diff .assert_not_called ()
634636 mock_print .assert_called_once ()
635637
@@ -805,8 +807,9 @@ def test_diff_only_prod_schema(
805807 @patch ("data_diff.dbt._cloud_diff" )
806808 @patch ("data_diff.dbt_parser.DbtParser.__new__" )
807809 @patch ("data_diff.dbt.rich.print" )
810+ @patch ("data_diff.dbt.DatafoldAPI" )
808811 def test_diff_is_cloud_no_pks (
809- self , mock_print , mock_dbt_parser , mock_cloud_diff , mock_local_diff , mock_get_diff_vars , mock_initialize_api
812+ self , mock_api , mock_print , mock_dbt_parser , mock_cloud_diff , mock_local_diff , mock_get_diff_vars , mock_initialize_api
810813 ):
811814 connection = {}
812815 threads = None
@@ -816,13 +819,11 @@ def test_diff_is_cloud_no_pks(
816819 "prod_schema" : "prod_schema" ,
817820 "datasource_id" : 1 ,
818821 }
819- host = "a_host"
820- api_key = "a_api_key"
821822 mock_dbt_parser_inst = Mock ()
822823 mock_dbt_parser .return_value = mock_dbt_parser_inst
823824 mock_model = Mock ()
824- api = DatafoldAPI ( api_key = api_key , host = host )
825- mock_initialize_api . return_value = api
825+ mock_initialize_api . return_value = mock_api
826+ mock_api . get_data_source . return_value = TCloudApiDataSource ( id = 1 , type = "snowflake" , name = "snowflake" )
826827
827828 mock_dbt_parser_inst .get_models .return_value = [mock_model ]
828829 mock_dbt_parser_inst .get_datadiff_variables .return_value = expected_dbt_vars_dict
@@ -840,6 +841,7 @@ def test_diff_is_cloud_no_pks(
840841 dbt_diff (is_cloud = True )
841842
842843 mock_initialize_api .assert_called_once ()
844+ mock_api .get_data_source .assert_called_once_with (1 )
843845 mock_dbt_parser_inst .get_models .assert_called_once ()
844846 mock_dbt_parser_inst .set_connection .assert_not_called ()
845847 mock_cloud_diff .assert_not_called ()
0 commit comments