From c218a45eacde499d94322129762ef952bb1e2590 Mon Sep 17 00:00:00 2001 From: SarveshSiras Date: Thu, 17 Jul 2025 18:15:52 -0400 Subject: [PATCH] Added code to return comparison result. --- app/core/compare.py | 117 ++++++++++++++++++++++++-------------------- app/main.py | 24 +++++++-- 2 files changed, 83 insertions(+), 58 deletions(-) diff --git a/app/core/compare.py b/app/core/compare.py index 346515f8..9de0d244 100644 --- a/app/core/compare.py +++ b/app/core/compare.py @@ -447,7 +447,7 @@ def compare_rules(self, query: str) -> Dict: section_comparisons, rule1_unique, rule2_unique, rule1, rule2, query ) - return { + result = { "rule1": rule1, "rule2": rule2, "topic": topic, @@ -464,6 +464,9 @@ def compare_rules(self, query: str) -> Dict: } } + response = json.dumps(result, default=self.clean_numpy, indent=2) + return response + # Include existing methods from original class def parse_comparison_query(self, query: str) -> Tuple[Dict, Dict, str]: """Parse user query to extract rules and topic (from original class)""" @@ -594,55 +597,63 @@ def semantic_similarity_search(self, query_embedding: np.ndarray, return relevant_chunks - -# Example usage -if __name__ == "__main__": - load_dotenv() - import os - - faiss_index_path = "./rag_data/faiss.index" - metadata_path = "./rag_data/faiss_metadata.json" - - comparator = SectionBySectionRuleComparator( - faiss_index_path=faiss_index_path, - metadata_path=metadata_path, - api_key=os.getenv("OPENAI_API_KEY") - ) - - # Test query - query = "Compare MPFS fee schedule for 2023 final rule and 2024 final rule" - - print(f"Testing section-by-section comparison: {query}") - result = comparator.compare_rules(query) - print (result) - - # if "error" in result: - # print(f"Error: {result['error']}") - # else: - # # print(f"\n{'='*80}") - # # print("SECTION-BY-SECTION COMPARISON RESULTS") - # # print('='*80) - - # # print(f"\nRules Compared:") - # # print(f"Rule 1: {result['rule1']['program']} {result['rule1']['year']} {result['rule1']['rule_type']}") - # # print(f"Rule 2: {result['rule2']['program']} {result['rule2']['year']} {result['rule2']['rule_type']}") - - # # #print(f"\nStats:") - # # stats = result['stats'] - # # print(f"- Sections compared: {stats['total_sections_compared']}") - # # print(f"- Rule 1 unique sections: {stats['rule1_unique_sections']}") - # # print(f"- Rule 2 unique sections: {stats['rule2_unique_sections']}") - - # print(f"\n{'='*60}") - # print("FINAL EXECUTIVE SUMMARY") - # print('='*60) - # print(result['final_summary']) - - # print(f"\n{'='*60}") - # print("DETAILED SECTION COMPARISONS") - # print('='*60) - # for i, comp in enumerate(result['section_comparisons'][:10], 1): # Show first 5 - # if not comp.get('error'): - # print(f"\n{i}. {comp['rule1_section']} <-> {comp['rule2_section']}") - # #print(f" Similarity: {comp['similarity_score']:.2f}") - # print(f" {comp['comparison']}") \ No newline at end of file + def clean_numpy(self, obj): + if isinstance(obj, np.floating): + return float(obj) + elif isinstance(obj, np.integer): + return int(obj) + elif isinstance(obj, (set, tuple)): + return list(obj) + return str(obj) + +# # Example usage +# if __name__ == "__main__": +# load_dotenv() +# import os +# +# faiss_index_path = "./rag_data/faiss.index" +# metadata_path = "./rag_data/faiss_metadata.json" +# +# comparator = SectionBySectionRuleComparator( +# faiss_index_path=faiss_index_path, +# metadata_path=metadata_path, +# api_key=os.getenv("OPENAI_API_KEY") +# ) +# +# # Test query +# query = "Compare MPFS fee schedule for 2023 final rule and 2024 final rule" +# +# print(f"Testing section-by-section comparison: {query}") +# result = comparator.compare_rules(query) +# print (result) +# +# # if "error" in result: +# # print(f"Error: {result['error']}") +# # else: +# # # print(f"\n{'='*80}") +# # # print("SECTION-BY-SECTION COMPARISON RESULTS") +# # # print('='*80) +# +# # # print(f"\nRules Compared:") +# # # print(f"Rule 1: {result['rule1']['program']} {result['rule1']['year']} {result['rule1']['rule_type']}") +# # # print(f"Rule 2: {result['rule2']['program']} {result['rule2']['year']} {result['rule2']['rule_type']}") +# +# # # #print(f"\nStats:") +# # # stats = result['stats'] +# # # print(f"- Sections compared: {stats['total_sections_compared']}") +# # # print(f"- Rule 1 unique sections: {stats['rule1_unique_sections']}") +# # # print(f"- Rule 2 unique sections: {stats['rule2_unique_sections']}") +# +# # print(f"\n{'='*60}") +# # print("FINAL EXECUTIVE SUMMARY") +# # print('='*60) +# # print(result['final_summary']) +# +# # print(f"\n{'='*60}") +# # print("DETAILED SECTION COMPARISONS") +# # print('='*60) +# # for i, comp in enumerate(result['section_comparisons'][:10], 1): # Show first 5 +# # if not comp.get('error'): +# # print(f"\n{i}. {comp['rule1_section']} <-> {comp['rule2_section']}") +# # #print(f" Similarity: {comp['similarity_score']:.2f}") +# # print(f" {comp['comparison']}") \ No newline at end of file diff --git a/app/main.py b/app/main.py index 9cd3c400..e6b3f04a 100644 --- a/app/main.py +++ b/app/main.py @@ -15,6 +15,7 @@ import requests from app.core import summarizer +from app.core.compare import SectionBySectionRuleComparator from app.core.summarizer import SummaryGenerator from .core.search import ChatSearchService from .config import config @@ -59,11 +60,17 @@ def create_app() -> Flask: openai_api_key=api_key ) + comparator = SectionBySectionRuleComparator( + faiss_index_path=config.faiss_index_path, + metadata_path=config.faiss_metadata_path, + api_key=api_key + ) + # Register error handlers register_error_handlers(app) # Register routes - register_routes(app, chat_service) + register_routes(app, chat_service, summarizer, comparator) return app @@ -407,7 +414,7 @@ def handle_bad_request(error: BadRequest) -> tuple[Dict[str, str], int]: return jsonify({"error": str(error)}), 400 -def register_routes(app: Flask, chat_service: ChatSearchService) -> None: +def register_routes(app: Flask, chat_service: ChatSearchService, summarizer: SummaryGenerator, comparator: SectionBySectionRuleComparator) -> None: """ Register routes for the Flask application. @@ -549,8 +556,6 @@ def summarize() -> tuple[Dict[str, str], int]: logger.error(f"Error in summarize endpoint: {str(e)}") return jsonify({"error": str(e)}), 400 - - @app.route("/api/get-summary", methods=["POST"]) def api_get_summary() -> tuple[Dict[str, Any], int]: """ @@ -641,7 +646,16 @@ def get_federal_register_info(doc_number: str) -> tuple[Dict[str, Any], int]: logger.error(f"Error in federal-register endpoint: {str(e)}") return jsonify({"error": str(e)}), 500 - + @app.route("/api/compare", methods=["POST"]) + def compare() -> tuple[Dict[str, Any], int]: + try: + data = validate_json_request(required_fields=["message"]) + query = data.get("message") + response = comparator.compare_rules(query) + return response#jsonify({"response": response}) + except Exception as e: + logger.error(f"Error in summarize endpoint: {str(e)}") + return jsonify({"error": str(e)}), 500 def main() -> None: """Main entry point for the Flask application."""