Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 64 additions & 53 deletions app/core/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,7 +447,7 @@ def compare_rules(self, query: str) -> Dict:
section_comparisons, rule1_unique, rule2_unique, rule1, rule2, query
)

return {
result = {
"rule1": rule1,
"rule2": rule2,
"topic": topic,
Expand All @@ -464,6 +464,9 @@ def compare_rules(self, query: str) -> Dict:
}
}

response = json.dumps(result, default=self.clean_numpy, indent=2)
return response

# Include existing methods from original class
def parse_comparison_query(self, query: str) -> Tuple[Dict, Dict, str]:
"""Parse user query to extract rules and topic (from original class)"""
Expand Down Expand Up @@ -594,55 +597,63 @@ def semantic_similarity_search(self, query_embedding: np.ndarray,

return relevant_chunks


# Example usage
if __name__ == "__main__":
load_dotenv()
import os

faiss_index_path = "./rag_data/faiss.index"
metadata_path = "./rag_data/faiss_metadata.json"

comparator = SectionBySectionRuleComparator(
faiss_index_path=faiss_index_path,
metadata_path=metadata_path,
api_key=os.getenv("OPENAI_API_KEY")
)

# Test query
query = "Compare MPFS fee schedule for 2023 final rule and 2024 final rule"

print(f"Testing section-by-section comparison: {query}")
result = comparator.compare_rules(query)
print (result)

# if "error" in result:
# print(f"Error: {result['error']}")
# else:
# # print(f"\n{'='*80}")
# # print("SECTION-BY-SECTION COMPARISON RESULTS")
# # print('='*80)

# # print(f"\nRules Compared:")
# # print(f"Rule 1: {result['rule1']['program']} {result['rule1']['year']} {result['rule1']['rule_type']}")
# # print(f"Rule 2: {result['rule2']['program']} {result['rule2']['year']} {result['rule2']['rule_type']}")

# # #print(f"\nStats:")
# # stats = result['stats']
# # print(f"- Sections compared: {stats['total_sections_compared']}")
# # print(f"- Rule 1 unique sections: {stats['rule1_unique_sections']}")
# # print(f"- Rule 2 unique sections: {stats['rule2_unique_sections']}")

# print(f"\n{'='*60}")
# print("FINAL EXECUTIVE SUMMARY")
# print('='*60)
# print(result['final_summary'])

# print(f"\n{'='*60}")
# print("DETAILED SECTION COMPARISONS")
# print('='*60)
# for i, comp in enumerate(result['section_comparisons'][:10], 1): # Show first 5
# if not comp.get('error'):
# print(f"\n{i}. {comp['rule1_section']} <-> {comp['rule2_section']}")
# #print(f" Similarity: {comp['similarity_score']:.2f}")
# print(f" {comp['comparison']}")
def clean_numpy(self, obj):
if isinstance(obj, np.floating):
return float(obj)
elif isinstance(obj, np.integer):
return int(obj)
elif isinstance(obj, (set, tuple)):
return list(obj)
return str(obj)

# # Example usage
# if __name__ == "__main__":
# load_dotenv()
# import os
#
# faiss_index_path = "./rag_data/faiss.index"
# metadata_path = "./rag_data/faiss_metadata.json"
#
# comparator = SectionBySectionRuleComparator(
# faiss_index_path=faiss_index_path,
# metadata_path=metadata_path,
# api_key=os.getenv("OPENAI_API_KEY")
# )
#
# # Test query
# query = "Compare MPFS fee schedule for 2023 final rule and 2024 final rule"
#
# print(f"Testing section-by-section comparison: {query}")
# result = comparator.compare_rules(query)
# print (result)
#
# # if "error" in result:
# # print(f"Error: {result['error']}")
# # else:
# # # print(f"\n{'='*80}")
# # # print("SECTION-BY-SECTION COMPARISON RESULTS")
# # # print('='*80)
#
# # # print(f"\nRules Compared:")
# # # print(f"Rule 1: {result['rule1']['program']} {result['rule1']['year']} {result['rule1']['rule_type']}")
# # # print(f"Rule 2: {result['rule2']['program']} {result['rule2']['year']} {result['rule2']['rule_type']}")
#
# # # #print(f"\nStats:")
# # # stats = result['stats']
# # # print(f"- Sections compared: {stats['total_sections_compared']}")
# # # print(f"- Rule 1 unique sections: {stats['rule1_unique_sections']}")
# # # print(f"- Rule 2 unique sections: {stats['rule2_unique_sections']}")
#
# # print(f"\n{'='*60}")
# # print("FINAL EXECUTIVE SUMMARY")
# # print('='*60)
# # print(result['final_summary'])
#
# # print(f"\n{'='*60}")
# # print("DETAILED SECTION COMPARISONS")
# # print('='*60)
# # for i, comp in enumerate(result['section_comparisons'][:10], 1): # Show first 5
# # if not comp.get('error'):
# # print(f"\n{i}. {comp['rule1_section']} <-> {comp['rule2_section']}")
# # #print(f" Similarity: {comp['similarity_score']:.2f}")
# # print(f" {comp['comparison']}")
24 changes: 19 additions & 5 deletions app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import requests

from app.core import summarizer
from app.core.compare import SectionBySectionRuleComparator
from app.core.summarizer import SummaryGenerator
from .core.search import ChatSearchService
from .config import config
Expand Down Expand Up @@ -59,11 +60,17 @@ def create_app() -> Flask:
openai_api_key=api_key
)

comparator = SectionBySectionRuleComparator(
faiss_index_path=config.faiss_index_path,
metadata_path=config.faiss_metadata_path,
api_key=api_key
)

# Register error handlers
register_error_handlers(app)

# Register routes
register_routes(app, chat_service)
register_routes(app, chat_service, summarizer, comparator)

return app

Expand Down Expand Up @@ -407,7 +414,7 @@ def handle_bad_request(error: BadRequest) -> tuple[Dict[str, str], int]:
return jsonify({"error": str(error)}), 400


def register_routes(app: Flask, chat_service: ChatSearchService) -> None:
def register_routes(app: Flask, chat_service: ChatSearchService, summarizer: SummaryGenerator, comparator: SectionBySectionRuleComparator) -> None:
"""
Register routes for the Flask application.

Expand Down Expand Up @@ -549,8 +556,6 @@ def summarize() -> tuple[Dict[str, str], int]:
logger.error(f"Error in summarize endpoint: {str(e)}")
return jsonify({"error": str(e)}), 400



@app.route("/api/get-summary", methods=["POST"])
def api_get_summary() -> tuple[Dict[str, Any], int]:
"""
Expand Down Expand Up @@ -641,7 +646,16 @@ def get_federal_register_info(doc_number: str) -> tuple[Dict[str, Any], int]:
logger.error(f"Error in federal-register endpoint: {str(e)}")
return jsonify({"error": str(e)}), 500


@app.route("/api/compare", methods=["POST"])
def compare() -> tuple[Dict[str, Any], int]:
try:
data = validate_json_request(required_fields=["message"])
query = data.get("message")
response = comparator.compare_rules(query)
return response#jsonify({"response": response})
except Exception as e:
logger.error(f"Error in summarize endpoint: {str(e)}")
return jsonify({"error": str(e)}), 500

def main() -> None:
"""Main entry point for the Flask application."""
Expand Down