Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit 2a85528

Browse files
committed
provide a default value for temporary schema
1 parent 7a5769d commit 2a85528

File tree

1 file changed

+20
-1
lines changed

1 file changed

+20
-1
lines changed

data_diff/cloud/data_source.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,14 +50,33 @@ def _validate_temp_schema(temp_schema: str):
5050
raise ValueError("Temporary schema should have a format <database>.<schema>")
5151

5252

53+
def _get_temp_schema(dbt_parser: DbtParser, db_type: str) -> Optional[str]:
54+
diff_vars = dbt_parser.get_datadiff_variables()
55+
config_prod_database = diff_vars.get("prod_database")
56+
config_prod_schema = diff_vars.get("prod_schema")
57+
if config_prod_database is not None and config_prod_schema is not None:
58+
temp_schema = f"{config_prod_database}.{config_prod_schema}"
59+
if db_type == "snowflake":
60+
return temp_schema.upper()
61+
elif db_type in {"pg", "postgres_aurora", "postgres_aws_rds", "redshift"}:
62+
return temp_schema.lower()
63+
return temp_schema
64+
return
65+
66+
5367
def create_ds_config(
5468
ds_config: TCloudApiDataSourceConfigSchema,
5569
data_source_name: str,
5670
dbt_parser: Optional[DbtParser] = None,
5771
) -> TDsConfig:
5872
options = _parse_ds_credentials(ds_config=ds_config, only_basic_settings=True, dbt_parser=dbt_parser)
5973

60-
temp_schema = TemporarySchemaPrompt.ask("Temporary schema (<database>.<schema>)")
74+
temp_schema = _get_temp_schema(dbt_parser=dbt_parser, db_type=ds_config.db_type) if dbt_parser else None
75+
if temp_schema:
76+
temp_schema = TemporarySchemaPrompt.ask("Temporary schema", default=temp_schema)
77+
else:
78+
temp_schema = TemporarySchemaPrompt.ask("Temporary schema (<database>.<schema>)")
79+
6180
float_tolerance = FloatPrompt.ask("Float tolerance", default=0.000001)
6281

6382
return TDsConfig(

0 commit comments

Comments
 (0)