@@ -50,14 +50,33 @@ def _validate_temp_schema(temp_schema: str):
5050 raise ValueError ("Temporary schema should have a format <database>.<schema>" )
5151
5252
53+ def _get_temp_schema (dbt_parser : DbtParser , db_type : str ) -> Optional [str ]:
54+ diff_vars = dbt_parser .get_datadiff_variables ()
55+ config_prod_database = diff_vars .get ("prod_database" )
56+ config_prod_schema = diff_vars .get ("prod_schema" )
57+ if config_prod_database is not None and config_prod_schema is not None :
58+ temp_schema = f"{ config_prod_database } .{ config_prod_schema } "
59+ if db_type == "snowflake" :
60+ return temp_schema .upper ()
61+ elif db_type in {"pg" , "postgres_aurora" , "postgres_aws_rds" , "redshift" }:
62+ return temp_schema .lower ()
63+ return temp_schema
64+ return
65+
66+
5367def create_ds_config (
5468 ds_config : TCloudApiDataSourceConfigSchema ,
5569 data_source_name : str ,
5670 dbt_parser : Optional [DbtParser ] = None ,
5771) -> TDsConfig :
5872 options = _parse_ds_credentials (ds_config = ds_config , only_basic_settings = True , dbt_parser = dbt_parser )
5973
60- temp_schema = TemporarySchemaPrompt .ask ("Temporary schema (<database>.<schema>)" )
74+ temp_schema = _get_temp_schema (dbt_parser = dbt_parser , db_type = ds_config .db_type ) if dbt_parser else None
75+ if temp_schema :
76+ temp_schema = TemporarySchemaPrompt .ask ("Temporary schema" , default = temp_schema )
77+ else :
78+ temp_schema = TemporarySchemaPrompt .ask ("Temporary schema (<database>.<schema>)" )
79+
6180 float_tolerance = FloatPrompt .ask ("Float tolerance" , default = 0.000001 )
6281
6382 return TDsConfig (
0 commit comments