11import os
22
33from pathlib import Path
4+
5+ from data_diff .cloud .datafold_api import TCloudApiDataSource
46from data_diff .diff_tables import Algorithm
57from .test_cli import run_datadiff_cli
68
@@ -183,12 +185,12 @@ def test_set_connection_snowflake_success_password(self):
183185
184186 DbtParser .set_connection (mock_self )
185187
188+ mock_self .set_casing_policy_for .assert_called_once_with ("snowflake" )
186189 self .assertIsInstance (mock_self .connection , dict )
187190 self .assertEqual (mock_self .connection .get ("driver" ), expected_driver )
188191 self .assertEqual (mock_self .connection .get ("user" ), expected_credentials ["user" ])
189192 self .assertEqual (mock_self .connection .get ("password" ), expected_credentials ["password" ])
190193 self .assertEqual (mock_self .connection .get ("key" ), None )
191- self .assertEqual (mock_self .requires_upper , True )
192194
193195 def test_set_connection_snowflake_success_key (self ):
194196 expected_driver = "snowflake"
@@ -198,12 +200,12 @@ def test_set_connection_snowflake_success_key(self):
198200
199201 DbtParser .set_connection (mock_self )
200202
203+ mock_self .set_casing_policy_for .assert_called_once_with ("snowflake" )
201204 self .assertIsInstance (mock_self .connection , dict )
202205 self .assertEqual (mock_self .connection .get ("driver" ), expected_driver )
203206 self .assertEqual (mock_self .connection .get ("user" ), expected_credentials ["user" ])
204207 self .assertEqual (mock_self .connection .get ("password" ), None )
205208 self .assertEqual (mock_self .connection .get ("key" ), expected_credentials ["private_key_path" ])
206- self .assertEqual (mock_self .requires_upper , True )
207209
208210 def test_set_connection_snowflake_success_key_and_passphrase (self ):
209211 expected_driver = "snowflake"
@@ -217,6 +219,7 @@ def test_set_connection_snowflake_success_key_and_passphrase(self):
217219
218220 DbtParser .set_connection (mock_self )
219221
222+ mock_self .set_casing_policy_for .assert_called_once_with ("snowflake" )
220223 self .assertIsInstance (mock_self .connection , dict )
221224 self .assertEqual (mock_self .connection .get ("driver" ), expected_driver )
222225 self .assertEqual (mock_self .connection .get ("user" ), expected_credentials ["user" ])
@@ -225,7 +228,6 @@ def test_set_connection_snowflake_success_key_and_passphrase(self):
225228 self .assertEqual (
226229 mock_self .connection .get ("private_key_passphrase" ), expected_credentials ["private_key_passphrase" ]
227230 )
228- self .assertEqual (mock_self .requires_upper , True )
229231
230232 def test_set_connection_snowflake_no_key_or_password (self ):
231233 expected_driver = "snowflake"
@@ -609,24 +611,22 @@ def test_cloud_diff(self, mock_api, mock_os_environ, mock_print):
609611 @patch ("data_diff.dbt._cloud_diff" )
610612 @patch ("data_diff.dbt_parser.DbtParser.__new__" )
611613 @patch ("data_diff.dbt.rich.print" )
614+ @patch ("data_diff.dbt.DatafoldAPI" )
612615 def test_diff_is_cloud (
613- self , mock_print , mock_dbt_parser , mock_cloud_diff , mock_local_diff , mock_get_diff_vars , mock_initialize_api
616+ self , mock_api , mock_print , mock_dbt_parser , mock_cloud_diff , mock_local_diff , mock_get_diff_vars , mock_initialize_api ,
614617 ):
615618 connection = {}
616619 threads = None
617620 where = "a_string"
618- host = "a_host"
619- api_key = "a_api_key"
620621 expected_dbt_vars_dict = {
621622 "prod_database" : "prod_db" ,
622623 "prod_schema" : "prod_schema" ,
623624 "datasource_id" : 1 ,
624625 }
625626 mock_dbt_parser_inst = Mock ()
626627 mock_model = Mock ()
627- api = DatafoldAPI (api_key = api_key , host = host )
628- mock_initialize_api .return_value = api
629-
628+ mock_api .get_data_source .return_value = TCloudApiDataSource (id = 1 , type = "snowflake" , name = "snowflake" )
629+ mock_initialize_api .return_value = mock_api
630630 mock_dbt_parser .return_value = mock_dbt_parser_inst
631631 mock_dbt_parser_inst .get_models .return_value = [mock_model ]
632632 mock_dbt_parser_inst .get_datadiff_variables .return_value = expected_dbt_vars_dict
@@ -645,9 +645,11 @@ def test_diff_is_cloud(
645645 dbt_diff (is_cloud = True )
646646 mock_dbt_parser_inst .get_models .assert_called_once ()
647647 mock_dbt_parser_inst .set_connection .assert_not_called ()
648+ mock_dbt_parser_inst .set_casing_policy_for .assert_called_once ()
648649
649650 mock_initialize_api .assert_called_once ()
650- mock_cloud_diff .assert_called_once_with (diff_vars , 1 , api )
651+ mock_api .get_data_source .assert_called_once_with (1 )
652+ mock_cloud_diff .assert_called_once_with (diff_vars , 1 , mock_api )
651653 mock_local_diff .assert_not_called ()
652654 mock_print .assert_called_once ()
653655
@@ -823,8 +825,9 @@ def test_diff_only_prod_schema(
823825 @patch ("data_diff.dbt._cloud_diff" )
824826 @patch ("data_diff.dbt_parser.DbtParser.__new__" )
825827 @patch ("data_diff.dbt.rich.print" )
828+ @patch ("data_diff.dbt.DatafoldAPI" )
826829 def test_diff_is_cloud_no_pks (
827- self , mock_print , mock_dbt_parser , mock_cloud_diff , mock_local_diff , mock_get_diff_vars , mock_initialize_api
830+ self , mock_api , mock_print , mock_dbt_parser , mock_cloud_diff , mock_local_diff , mock_get_diff_vars , mock_initialize_api
828831 ):
829832 connection = {}
830833 threads = None
@@ -834,13 +837,11 @@ def test_diff_is_cloud_no_pks(
834837 "prod_schema" : "prod_schema" ,
835838 "datasource_id" : 1 ,
836839 }
837- host = "a_host"
838- api_key = "a_api_key"
839840 mock_dbt_parser_inst = Mock ()
840841 mock_dbt_parser .return_value = mock_dbt_parser_inst
841842 mock_model = Mock ()
842- api = DatafoldAPI ( api_key = api_key , host = host )
843- mock_initialize_api . return_value = api
843+ mock_initialize_api . return_value = mock_api
844+ mock_api . get_data_source . return_value = TCloudApiDataSource ( id = 1 , type = "snowflake" , name = "snowflake" )
844845
845846 mock_dbt_parser_inst .get_models .return_value = [mock_model ]
846847 mock_dbt_parser_inst .get_datadiff_variables .return_value = expected_dbt_vars_dict
@@ -858,6 +859,7 @@ def test_diff_is_cloud_no_pks(
858859 dbt_diff (is_cloud = True )
859860
860861 mock_initialize_api .assert_called_once ()
862+ mock_api .get_data_source .assert_called_once_with (1 )
861863 mock_dbt_parser_inst .get_models .assert_called_once ()
862864 mock_dbt_parser_inst .set_connection .assert_not_called ()
863865 mock_cloud_diff .assert_not_called ()
0 commit comments