11import os
22import time
33import webbrowser
4+ import pydantic
45import rich
56from rich .prompt import Confirm
67
7- from dataclasses import dataclass
88from typing import List , Optional , Dict
99from .utils import dbt_diff_string_template , getLogger
1010from pathlib import Path
3232from . import connect_to_table , diff_tables , Algorithm
3333
3434
class TDiffVars(pydantic.BaseModel):
    """Validated settings describing a single dev-vs-prod model diff.

    Instances are built by ``_get_diff_vars`` and consumed by the local and
    cloud diff routines in this module.
    """

    # Qualified name parts of the dev table; joined with "." where displayed.
    dev_path: List[str]
    # Qualified name parts of the prod table; joined with "." where displayed.
    prod_path: List[str]
    # Key columns; the local diff removes these from the compared column set
    # and the cloud diff sends them as ``pk_columns``.
    primary_keys: List[str]
    # Connection settings taken from the dbt parser; values may be None.
    connection: Dict[str, Optional[str]]
    # Thread count taken from the dbt parser; None when not configured.
    threads: Optional[int] = None
    # Optional filter expression; the cloud diff applies it to both tables.
    where_filter: Optional[str] = None
    # When non-empty, only these columns (matched case-insensitively) are
    # compared. An empty list means "no restriction", hence the empty default.
    include_columns: List[str] = pydantic.Field(default_factory=list)
    # When non-empty, these columns (matched case-insensitively) are left out
    # of the comparison. An empty list means "exclude nothing".
    exclude_columns: List[str] = pydantic.Field(default_factory=list)
4344
4445
4546def dbt_diff (
@@ -122,7 +123,7 @@ def _get_diff_vars(
122123 config_prod_schema : Optional [str ],
123124 config_prod_custom_schema : Optional [str ],
124125 model ,
125- ) -> DiffVars :
126+ ) -> TDiffVars :
126127 dev_database = model .database
127128 dev_schema = model .schema_
128129
@@ -156,19 +157,21 @@ def _get_diff_vars(
156157 dev_qualified_list = [x for x in [dev_database , dev_schema , model .alias ] if x ]
157158 prod_qualified_list = [x for x in [prod_database , prod_schema , model .alias ] if x ]
158159
159- where_filter = None
160- if model .meta :
161- try :
162- where_filter = model .meta ["datafold" ]["datadiff" ]["filter" ]
163- except KeyError :
164- pass
165-
166- return DiffVars (
167- dev_qualified_list , prod_qualified_list , primary_keys , dbt_parser .connection , dbt_parser .threads , where_filter
160+ datadiff_model_config = dbt_parser .get_datadiff_model_config (model .meta )
161+
162+ return TDiffVars (
163+ dev_path = dev_qualified_list ,
164+ prod_path = prod_qualified_list ,
165+ primary_keys = primary_keys ,
166+ connection = dbt_parser .connection ,
167+ threads = dbt_parser .threads ,
168+ where_filter = datadiff_model_config .where_filter ,
169+ include_columns = datadiff_model_config .include_columns ,
170+ exclude_columns = datadiff_model_config .exclude_columns ,
168171 )
169172
170173
171- def _local_diff (diff_vars : DiffVars ) -> None :
174+ def _local_diff (diff_vars : TDiffVars ) -> None :
172175 column_diffs_str = ""
173176 dev_qualified_str = "." .join (diff_vars .dev_path )
174177 prod_qualified_str = "." .join (diff_vars .prod_path )
@@ -189,18 +192,25 @@ def _local_diff(diff_vars: DiffVars) -> None:
189192 rich .print (diff_output_str )
190193 return
191194
192- mutual_set = set (table1_columns ) & set (table2_columns )
193- table1_set_diff = list (set (table1_columns ) - set (table2_columns ))
194- table2_set_diff = list (set (table2_columns ) - set (table1_columns ))
195+ column_set = set (table1_columns ).intersection (table2_columns )
196+ table1_diff = set (table1_columns ).difference (table2_columns )
197+ table2_diff = set (table2_columns ).difference (table1_columns )
198+
199+ if table1_diff :
200+ column_diffs_str += f"Column(s) added: { table1_diff } \n "
201+
202+ if table2_diff :
203+ column_diffs_str += f"Column(s) removed: { table2_diff } \n "
204+
205+ column_set = column_set - set (diff_vars .primary_keys )
195206
196- if table1_set_diff :
197- column_diffs_str += "Column(s) added: " + str ( table1_set_diff ) + " \n "
207+ if diff_vars . include_columns :
208+ column_set = { x for x in column_set if x . upper () in [ y . upper () for y in diff_vars . include_columns ]}
198209
199- if table2_set_diff :
200- column_diffs_str += "Column(s) removed: " + str ( table2_set_diff ) + " \n "
210+ if diff_vars . exclude_columns :
211+ column_set = { x for x in column_set if x . upper () not in [ y . upper () for y in diff_vars . exclude_columns ]}
201212
202- mutual_set = mutual_set - set (diff_vars .primary_keys )
203- extra_columns = tuple (mutual_set )
213+ extra_columns = tuple (column_set )
204214
205215 diff = diff_tables (
206216 table1 ,
@@ -250,7 +260,7 @@ def _initialize_api() -> Optional[DatafoldAPI]:
250260 return DatafoldAPI (api_key = api_key , host = datafold_host )
251261
252262
253- def _cloud_diff (diff_vars : DiffVars , datasource_id : int , api : DatafoldAPI ) -> None :
263+ def _cloud_diff (diff_vars : TDiffVars , datasource_id : int , api : DatafoldAPI ) -> None :
254264 diff_output_str = _diff_output_base ("." .join (diff_vars .dev_path ), "." .join (diff_vars .prod_path ))
255265 payload = TCloudApiDataDiff (
256266 data_source1_id = datasource_id ,
@@ -260,6 +270,8 @@ def _cloud_diff(diff_vars: DiffVars, datasource_id: int, api: DatafoldAPI) -> No
260270 pk_columns = diff_vars .primary_keys ,
261271 filter1 = diff_vars .where_filter ,
262272 filter2 = diff_vars .where_filter ,
273+ include_columns = diff_vars .include_columns ,
274+ exclude_columns = diff_vars .exclude_columns ,
263275 )
264276
265277 if is_tracking_enabled ():
0 commit comments