From 5441c19c1b6de92595787e35b35af8d94918dca9 Mon Sep 17 00:00:00 2001 From: "google-labs-jules[bot]" <161369871+google-labs-jules[bot]@users.noreply.github.com> Date: Mon, 8 Dec 2025 19:14:54 +0000 Subject: [PATCH] feat: enhance /broadcast with targeting options (regular/authorized) - Enhanced `/broadcast` command to support optional arguments: `authorized` and `regular`. - `authorized`: Broadcasts only to users in the authorized list. - `regular`: Broadcasts only to users NOT in the authorized list. - Default behavior remains broadcasting to all users. - Refactored `broadcast_message` utility: - Added `mode` parameter. - Implemented dynamic cursor selection. - Added robust retry logic for `FloodWait` errors. - Added safety check to prevent deleting authorized users on delivery failure. - Updated `Database` class with helper methods for counting and fetching regular/authorized users efficiently using `$nin`. --- .python-version | 2 +- Thunder/bot/plugins/admin.py | 10 +- Thunder/utils/broadcast.py | 310 ++++++++++++------------ Thunder/utils/database.py | 458 +++++++++++++++++++---------------- 4 files changed, 414 insertions(+), 366 deletions(-) diff --git a/.python-version b/.python-version index 3a4f41e..24ee5b1 100644 --- a/.python-version +++ b/.python-version @@ -1 +1 @@ -3.13 \ No newline at end of file +3.13 diff --git a/Thunder/bot/plugins/admin.py b/Thunder/bot/plugins/admin.py index 331ce6f..c714784 100644 --- a/Thunder/bot/plugins/admin.py +++ b/Thunder/bot/plugins/admin.py @@ -65,7 +65,15 @@ async def get_total_users(client: Client, message: Message): @StreamBot.on_message(filters.command("broadcast") & owner_filter) async def broadcast_handler(client: Client, message: Message): - await broadcast_message(client, message) + mode = "all" + if len(message.command) > 1: + arg = message.command[1].lower() + if arg == "authorized": + mode = "authorized" + elif arg == "regular": + mode = "regular" + + await broadcast_message(client, message, mode=mode) @StreamBot.on_message(filters.command("status") & owner_filter) diff --git a/Thunder/utils/broadcast.py b/Thunder/utils/broadcast.py index 2518dee..2476929 100644 --- a/Thunder/utils/broadcast.py +++ b/Thunder/utils/broadcast.py @@ -1,150 +1,160 @@ -# Thunder/utils/broadcast.py - -import asyncio -import os -import time - -from pyrogram.client import Client -from pyrogram.enums import ParseMode -from pyrogram.errors import (ChatWriteForbidden, FloodWait, PeerIdInvalid, UserDeactivated, - UserIsBlocked, ChannelInvalid, InputUserDeactivated) -from pyrogram.types import (InlineKeyboardButton, InlineKeyboardMarkup, - Message) - -from Thunder.utils.database import db -from Thunder.utils.logger import logger -from Thunder.utils.messages import ( - MSG_INVALID_BROADCAST_CMD, - MSG_BROADCAST_START, - MSG_BUTTON_CANCEL_BROADCAST, - MSG_BROADCAST_COMPLETE -) -from Thunder.utils.time_format import get_readable_time - - -broadcast_ids = {} - -async def broadcast_message(client: Client, message: Message): - if not message.reply_to_message: - try: - await message.reply_text(MSG_INVALID_BROADCAST_CMD) - except FloodWait as e: - await asyncio.sleep(e.value) - await message.reply_text(MSG_INVALID_BROADCAST_CMD) - return - - broadcast_id = os.urandom(3).hex() - stats = {"total": 0, "success": 0, "failed": 0, "deleted": 0, "cancelled": False} - broadcast_ids[broadcast_id] = stats - - try: - status_msg = await message.reply_text( - MSG_BROADCAST_START, - reply_markup=InlineKeyboardMarkup([[ - InlineKeyboardButton(MSG_BUTTON_CANCEL_BROADCAST, callback_data=f"cancel_{broadcast_id}") - ]]) - ) - except FloodWait as e: - await asyncio.sleep(e.value) - status_msg = await message.reply_text( - MSG_BROADCAST_START, - reply_markup=InlineKeyboardMarkup([[ - InlineKeyboardButton(MSG_BUTTON_CANCEL_BROADCAST, callback_data=f"cancel_{broadcast_id}") - ]]) - ) - - start_time = time.time() - stats["total"] = await db.total_users_count() - - async def do_broadcast(): - async for user in db.get_all_users(): - if stats["cancelled"]: - break - try: - try: - result = await message.reply_to_message.copy(user['id']) - if result: - stats["success"] += 1 - else: - stats["failed"] += 1 - except (UserDeactivated, UserIsBlocked, PeerIdInvalid, ChatWriteForbidden, ChannelInvalid, InputUserDeactivated) as e: - if isinstance(e, ChannelInvalid): - recipient_type = "Channel" - reason = "invalid channel" - elif isinstance(e, InputUserDeactivated): - recipient_type = "User" - reason = "deactivated account" - elif isinstance(e, UserIsBlocked): - recipient_type = "User" - reason = "blocked the bot" - elif isinstance(e, UserDeactivated): - recipient_type = "User" - reason = "deactivated account" - elif isinstance(e, PeerIdInvalid): - recipient_type = "Recipient" - reason = "invalid ID" - elif isinstance(e, ChatWriteForbidden): - recipient_type = "Chat" - reason = "write forbidden" - else: - recipient_type = "Recipient" - reason = f"error: {type(e).__name__}" - - logger.warning(f"{recipient_type} {user['id']} removed due to {reason}") - - await db.delete_user(user['id']) - stats["deleted"] += 1 - continue - - except FloodWait as e: - await asyncio.sleep(e.value) - try: - result = await message.reply_to_message.copy(user['id']) - if result: - stats["success"] += 1 - else: - stats["failed"] += 1 - except FloodWait as e2: - await asyncio.sleep(e2.value) - result = await message.reply_to_message.copy(user['id']) - if result: - stats["success"] += 1 - else: - stats["failed"] += 1 - except Exception as e: - logger.error(f"Error copying message to user {user['id']}: {e}", exc_info=True) - stats["failed"] += 1 - - try: - await status_msg.delete() - except FloodWait as e: - await asyncio.sleep(e.value) - await status_msg.delete() - - try: - await message.reply_text( - MSG_BROADCAST_COMPLETE.format( - elapsed_time=get_readable_time(int(time.time() - start_time)), - total_users=stats["total"], - successes=stats["success"], - failures=stats["failed"], - deleted_accounts=stats["deleted"] - ), - parse_mode=ParseMode.MARKDOWN - ) - except FloodWait as e: - await asyncio.sleep(e.value) - await message.reply_text( - MSG_BROADCAST_COMPLETE.format( - elapsed_time=get_readable_time(int(time.time() - start_time)), - total_users=stats["total"], - successes=stats["success"], - failures=stats["failed"], - deleted_accounts=stats["deleted"] - ), - parse_mode=ParseMode.MARKDOWN - ) - - del broadcast_ids[broadcast_id] - - asyncio.create_task(do_broadcast()) +# Thunder/utils/broadcast.py + +import asyncio +import os +import time + +from pyrogram.client import Client +from pyrogram.enums import ParseMode +from pyrogram.errors import (ChatWriteForbidden, FloodWait, PeerIdInvalid, UserDeactivated, + UserIsBlocked, ChannelInvalid, InputUserDeactivated) +from pyrogram.types import (InlineKeyboardButton, InlineKeyboardMarkup, + Message) + +from Thunder.utils.database import db +from Thunder.utils.logger import logger +from Thunder.utils.messages import ( + MSG_INVALID_BROADCAST_CMD, + MSG_BROADCAST_START, + MSG_BUTTON_CANCEL_BROADCAST, + MSG_BROADCAST_COMPLETE +) +from Thunder.utils.time_format import get_readable_time + + +broadcast_ids = {} + +async def broadcast_message(client: Client, message: Message, mode: str = "all"): + if not message.reply_to_message: + try: + await message.reply_text(MSG_INVALID_BROADCAST_CMD) + except FloodWait as e: + await asyncio.sleep(e.value) + await message.reply_text(MSG_INVALID_BROADCAST_CMD) + return + + broadcast_id = os.urandom(3).hex() + stats = {"total": 0, "success": 0, "failed": 0, "deleted": 0, "cancelled": False} + broadcast_ids[broadcast_id] = stats + + try: + status_msg = await message.reply_text( + MSG_BROADCAST_START, + reply_markup=InlineKeyboardMarkup([[ + InlineKeyboardButton(MSG_BUTTON_CANCEL_BROADCAST, callback_data=f"cancel_{broadcast_id}") + ]]) + ) + except FloodWait as e: + await asyncio.sleep(e.value) + status_msg = await message.reply_text( + MSG_BROADCAST_START, + reply_markup=InlineKeyboardMarkup([[ + InlineKeyboardButton(MSG_BUTTON_CANCEL_BROADCAST, callback_data=f"cancel_{broadcast_id}") + ]]) + ) + + start_time = time.time() + + if mode == "authorized": + stats["total"] = await db.get_authorized_users_count() + cursor = db.get_authorized_users_cursor() + elif mode == "regular": + stats["total"] = await db.get_regular_users_count() + cursor = await db.get_regular_users_cursor() + else: + stats["total"] = await db.total_users_count() + cursor = db.get_all_users() + + async def do_broadcast(): + async for user in cursor: + if stats["cancelled"]: + break + + user_id = user.get('id') or user.get('user_id') + if not user_id: + logger.warning(f"Skipping user with no ID: {user}") + continue + + try: + msg = None + # Retry loop for FloodWait + for _ in range(3): + try: + msg = await message.reply_to_message.copy(user_id) + break + except FloodWait as e: + await asyncio.sleep(e.value) + except Exception: + raise # Raise other exceptions to be caught below + + if msg: + stats["success"] += 1 + else: + stats["failed"] += 1 + + except (UserDeactivated, UserIsBlocked, PeerIdInvalid, ChatWriteForbidden, ChannelInvalid, InputUserDeactivated) as e: + if isinstance(e, ChannelInvalid): + recipient_type = "Channel" + reason = "invalid channel" + elif isinstance(e, InputUserDeactivated): + recipient_type = "User" + reason = "deactivated account" + elif isinstance(e, UserIsBlocked): + recipient_type = "User" + reason = "blocked the bot" + elif isinstance(e, UserDeactivated): + recipient_type = "User" + reason = "deactivated account" + elif isinstance(e, PeerIdInvalid): + recipient_type = "Recipient" + reason = "invalid ID" + elif isinstance(e, ChatWriteForbidden): + recipient_type = "Chat" + reason = "write forbidden" + else: + recipient_type = "Recipient" + reason = f"error: {type(e).__name__}" + + logger.warning(f"{recipient_type} {user_id} removed due to {reason}") + + if mode != "authorized": + await db.delete_user(user_id) + stats["deleted"] += 1 + + except Exception as e: + logger.error(f"Error copying message to user {user_id}: {e}", exc_info=True) + stats["failed"] += 1 + + try: + await status_msg.delete() + except FloodWait as e: + await asyncio.sleep(e.value) + await status_msg.delete() + + try: + await message.reply_text( + MSG_BROADCAST_COMPLETE.format( + elapsed_time=get_readable_time(int(time.time() - start_time)), + total_users=stats["total"], + successes=stats["success"], + failures=stats["failed"], + deleted_accounts=stats["deleted"] + ), + parse_mode=ParseMode.MARKDOWN + ) + except FloodWait as e: + await asyncio.sleep(e.value) + await message.reply_text( + MSG_BROADCAST_COMPLETE.format( + elapsed_time=get_readable_time(int(time.time() - start_time)), + total_users=stats["total"], + successes=stats["success"], + failures=stats["failed"], + deleted_accounts=stats["deleted"] + ), + parse_mode=ParseMode.MARKDOWN + ) + + del broadcast_ids[broadcast_id] + + asyncio.create_task(do_broadcast()) diff --git a/Thunder/utils/database.py b/Thunder/utils/database.py index 6d32f95..7873e78 100644 --- a/Thunder/utils/database.py +++ b/Thunder/utils/database.py @@ -1,214 +1,244 @@ -# Thunder/utils/database.py - -import datetime -from typing import Optional, Dict, Any -from pymongo import AsyncMongoClient -from pymongo.asynchronous.collection import AsyncCollection -from Thunder.vars import Var -from Thunder.utils.logger import logger - -class Database: - def __init__(self, uri: str, database_name: str, *args, **kwargs): - self._client = AsyncMongoClient(uri, *args, **kwargs) - self.db = self._client[database_name] - self.col: AsyncCollection = self.db.users - self.banned_users_col: AsyncCollection = self.db.banned_users - self.banned_channels_col: AsyncCollection = self.db.banned_channels - self.token_col: AsyncCollection = self.db.tokens - self.authorized_users_col: AsyncCollection = self.db.authorized_users - self.restart_message_col: AsyncCollection = self.db.restart_message - - async def ensure_indexes(self): - try: - await self.banned_users_col.create_index("user_id", unique=True) - await self.banned_channels_col.create_index("channel_id", unique=True) - await self.token_col.create_index("token", unique=True) - await self.authorized_users_col.create_index("user_id", unique=True) - await self.col.create_index("id", unique=True) - await self.token_col.create_index("expires_at", expireAfterSeconds=0) - await self.token_col.create_index("activated") - await self.restart_message_col.create_index("message_id", unique=True) - await self.restart_message_col.create_index("timestamp", expireAfterSeconds=3600) - - logger.debug("Database indexes ensured.") - except Exception as e: - logger.error(f"Error in ensure_indexes: {e}", exc_info=True) - raise - - def new_user(self, user_id: int) -> dict: - try: - return { - 'id': user_id, - 'join_date': datetime.datetime.utcnow() - } - except Exception as e: - logger.error(f"Error in new_user for user {user_id}: {e}", exc_info=True) - raise - - async def add_user(self, user_id: int): - try: - if not await self.is_user_exist(user_id): - await self.col.insert_one(self.new_user(user_id)) - logger.debug(f"Added new user {user_id} to database.") - except Exception as e: - logger.error(f"Error in add_user for user {user_id}: {e}", exc_info=True) - raise - - - async def is_user_exist(self, user_id: int) -> bool: - try: - user = await self.col.find_one({'id': user_id}, {'_id': 1}) - return bool(user) - except Exception as e: - logger.error(f"Error in is_user_exist for user {user_id}: {e}", exc_info=True) - raise - - async def total_users_count(self) -> int: - try: - return await self.col.count_documents({}) - except Exception as e: - logger.error(f"Error in total_users_count: {e}", exc_info=True) - return 0 - - def get_all_users(self): - try: - return self.col.find({}) - except Exception as e: - logger.error(f"Error in get_all_users: {e}", exc_info=True) - return self.col.find({"_id": {"$exists": False}}) - - async def delete_user(self, user_id: int): - try: - await self.col.delete_one({'id': user_id}) - logger.debug(f"Deleted user {user_id}.") - except Exception as e: - logger.error(f"Error in delete_user for user {user_id}: {e}", exc_info=True) - raise - - - async def add_banned_user( - self, user_id: int, banned_by: Optional[int] = None, - reason: Optional[str] = None - ): - try: - ban_data = { - "user_id": user_id, - "banned_at": datetime.datetime.utcnow(), - "banned_by": banned_by, - "reason": reason - } - await self.banned_users_col.update_one( - {"user_id": user_id}, - {"$set": ban_data}, - upsert=True - ) - logger.debug(f"Added/Updated banned user {user_id}. Reason: {reason}") - except Exception as e: - logger.error(f"Error in add_banned_user for user {user_id}: {e}", exc_info=True) - raise - - async def remove_banned_user(self, user_id: int) -> bool: - try: - result = await self.banned_users_col.delete_one({"user_id": user_id}) - if result.deleted_count > 0: - logger.debug(f"Removed banned user {user_id}.") - return True - return False - except Exception as e: - logger.error(f"Error in remove_banned_user for user {user_id}: {e}", exc_info=True) - return False - - async def is_user_banned(self, user_id: int) -> Optional[Dict[str, Any]]: - try: - return await self.banned_users_col.find_one({"user_id": user_id}) - except Exception as e: - logger.error(f"Error in is_user_banned for user {user_id}: {e}", exc_info=True) - return None - - async def add_banned_channel( - self, channel_id: int, banned_by: Optional[int] = None, - reason: Optional[str] = None - ): - try: - ban_data = { - "channel_id": channel_id, - "banned_at": datetime.datetime.utcnow(), - "banned_by": banned_by, - "reason": reason - } - await self.banned_channels_col.update_one( - {"channel_id": channel_id}, - {"$set": ban_data}, - upsert=True - ) - logger.debug(f"Added/Updated banned channel {channel_id}. Reason: {reason}") - except Exception as e: - logger.error(f"Error in add_banned_channel for channel {channel_id}: {e}", exc_info=True) - raise - - async def remove_banned_channel(self, channel_id: int) -> bool: - try: - result = await self.banned_channels_col.delete_one({"channel_id": channel_id}) - if result.deleted_count > 0: - logger.debug(f"Removed banned channel {channel_id}.") - return True - return False - except Exception as e: - logger.error(f"Error in remove_banned_channel for channel {channel_id}: {e}", exc_info=True) - return False - - async def is_channel_banned(self, channel_id: int) -> Optional[Dict[str, Any]]: - try: - return await self.banned_channels_col.find_one({"channel_id": channel_id}) - except Exception as e: - logger.error(f"Error in is_channel_banned for channel {channel_id}: {e}", exc_info=True) - return None - - async def save_main_token(self, user_id: int, token_value: str, expires_at: datetime.datetime, created_at: datetime.datetime, activated: bool) -> None: - try: - await self.token_col.update_one( - {"user_id": user_id, "token": token_value}, - {"$set": { - "expires_at": expires_at, - "created_at": created_at, - "activated": activated - } - }, - upsert=True - ) - logger.debug(f"Saved main token {token_value} for user {user_id} with activated status {activated}.") - except Exception as e: - logger.error(f"Error saving main token for user {user_id}: {e}", exc_info=True) - raise - - - async def add_restart_message(self, message_id: int, chat_id: int) -> None: - try: - await self.restart_message_col.insert_one({ - "message_id": message_id, - "chat_id": chat_id, - "timestamp": datetime.datetime.utcnow() - }) - logger.debug(f"Added restart message {message_id} for chat {chat_id}.") - except Exception as e: - logger.error(f"Error adding restart message {message_id}: {e}", exc_info=True) - - async def get_restart_message(self) -> Optional[Dict[str, Any]]: - try: - return await self.restart_message_col.find_one(sort=[("timestamp", -1)]) - except Exception as e: - logger.error(f"Error getting restart message: {e}", exc_info=True) - return None - - async def delete_restart_message(self, message_id: int) -> None: - try: - await self.restart_message_col.delete_one({"message_id": message_id}) - logger.debug(f"Deleted restart message {message_id}.") - except Exception as e: - logger.error(f"Error deleting restart message {message_id}: {e}", exc_info=True) - - async def close(self): - if self._client: - await self._client.close() - -db = Database(Var.DATABASE_URL, Var.NAME) +# Thunder/utils/database.py + +import datetime +from typing import Optional, Dict, Any +from pymongo import AsyncMongoClient +from pymongo.asynchronous.collection import AsyncCollection +from Thunder.vars import Var +from Thunder.utils.logger import logger + +class Database: + def __init__(self, uri: str, database_name: str, *args, **kwargs): + self._client = AsyncMongoClient(uri, *args, **kwargs) + self.db = self._client[database_name] + self.col: AsyncCollection = self.db.users + self.banned_users_col: AsyncCollection = self.db.banned_users + self.banned_channels_col: AsyncCollection = self.db.banned_channels + self.token_col: AsyncCollection = self.db.tokens + self.authorized_users_col: AsyncCollection = self.db.authorized_users + self.restart_message_col: AsyncCollection = self.db.restart_message + + async def ensure_indexes(self): + try: + await self.banned_users_col.create_index("user_id", unique=True) + await self.banned_channels_col.create_index("channel_id", unique=True) + await self.token_col.create_index("token", unique=True) + await self.authorized_users_col.create_index("user_id", unique=True) + await self.col.create_index("id", unique=True) + await self.token_col.create_index("expires_at", expireAfterSeconds=0) + await self.token_col.create_index("activated") + await self.restart_message_col.create_index("message_id", unique=True) + await self.restart_message_col.create_index("timestamp", expireAfterSeconds=3600) + + logger.debug("Database indexes ensured.") + except Exception as e: + logger.error(f"Error in ensure_indexes: {e}", exc_info=True) + raise + + def new_user(self, user_id: int) -> dict: + try: + return { + 'id': user_id, + 'join_date': datetime.datetime.utcnow() + } + except Exception as e: + logger.error(f"Error in new_user for user {user_id}: {e}", exc_info=True) + raise + + async def add_user(self, user_id: int): + try: + if not await self.is_user_exist(user_id): + await self.col.insert_one(self.new_user(user_id)) + logger.debug(f"Added new user {user_id} to database.") + except Exception as e: + logger.error(f"Error in add_user for user {user_id}: {e}", exc_info=True) + raise + + + async def is_user_exist(self, user_id: int) -> bool: + try: + user = await self.col.find_one({'id': user_id}, {'_id': 1}) + return bool(user) + except Exception as e: + logger.error(f"Error in is_user_exist for user {user_id}: {e}", exc_info=True) + raise + + async def total_users_count(self) -> int: + try: + return await self.col.count_documents({}) + except Exception as e: + logger.error(f"Error in total_users_count: {e}", exc_info=True) + return 0 + + async def get_authorized_users_count(self) -> int: + try: + return await self.authorized_users_col.count_documents({}) + except Exception as e: + logger.error(f"Error in get_authorized_users_count: {e}", exc_info=True) + return 0 + + async def get_regular_users_count(self) -> int: + try: + auth_ids = await self.authorized_users_col.distinct("user_id") + return await self.col.count_documents({"id": {"$nin": auth_ids}}) + except Exception as e: + logger.error(f"Error in get_regular_users_count: {e}", exc_info=True) + return 0 + + def get_all_users(self): + try: + return self.col.find({}) + except Exception as e: + logger.error(f"Error in get_all_users: {e}", exc_info=True) + return self.col.find({"_id": {"$exists": False}}) + + def get_authorized_users_cursor(self): + try: + return self.authorized_users_col.find({}) + except Exception as e: + logger.error(f"Error in get_authorized_users_cursor: {e}", exc_info=True) + return self.authorized_users_col.find({"_id": {"$exists": False}}) + + async def get_regular_users_cursor(self): + try: + auth_ids = await self.authorized_users_col.distinct("user_id") + return self.col.find({"id": {"$nin": auth_ids}}) + except Exception as e: + logger.error(f"Error in get_regular_users_cursor: {e}", exc_info=True) + return self.col.find({"_id": {"$exists": False}}) + + async def delete_user(self, user_id: int): + try: + await self.col.delete_one({'id': user_id}) + logger.debug(f"Deleted user {user_id}.") + except Exception as e: + logger.error(f"Error in delete_user for user {user_id}: {e}", exc_info=True) + raise + + + async def add_banned_user( + self, user_id: int, banned_by: Optional[int] = None, + reason: Optional[str] = None + ): + try: + ban_data = { + "user_id": user_id, + "banned_at": datetime.datetime.utcnow(), + "banned_by": banned_by, + "reason": reason + } + await self.banned_users_col.update_one( + {"user_id": user_id}, + {"$set": ban_data}, + upsert=True + ) + logger.debug(f"Added/Updated banned user {user_id}. Reason: {reason}") + except Exception as e: + logger.error(f"Error in add_banned_user for user {user_id}: {e}", exc_info=True) + raise + + async def remove_banned_user(self, user_id: int) -> bool: + try: + result = await self.banned_users_col.delete_one({"user_id": user_id}) + if result.deleted_count > 0: + logger.debug(f"Removed banned user {user_id}.") + return True + return False + except Exception as e: + logger.error(f"Error in remove_banned_user for user {user_id}: {e}", exc_info=True) + return False + + async def is_user_banned(self, user_id: int) -> Optional[Dict[str, Any]]: + try: + return await self.banned_users_col.find_one({"user_id": user_id}) + except Exception as e: + logger.error(f"Error in is_user_banned for user {user_id}: {e}", exc_info=True) + return None + + async def add_banned_channel( + self, channel_id: int, banned_by: Optional[int] = None, + reason: Optional[str] = None + ): + try: + ban_data = { + "channel_id": channel_id, + "banned_at": datetime.datetime.utcnow(), + "banned_by": banned_by, + "reason": reason + } + await self.banned_channels_col.update_one( + {"channel_id": channel_id}, + {"$set": ban_data}, + upsert=True + ) + logger.debug(f"Added/Updated banned channel {channel_id}. Reason: {reason}") + except Exception as e: + logger.error(f"Error in add_banned_channel for channel {channel_id}: {e}", exc_info=True) + raise + + async def remove_banned_channel(self, channel_id: int) -> bool: + try: + result = await self.banned_channels_col.delete_one({"channel_id": channel_id}) + if result.deleted_count > 0: + logger.debug(f"Removed banned channel {channel_id}.") + return True + return False + except Exception as e: + logger.error(f"Error in remove_banned_channel for channel {channel_id}: {e}", exc_info=True) + return False + + async def is_channel_banned(self, channel_id: int) -> Optional[Dict[str, Any]]: + try: + return await self.banned_channels_col.find_one({"channel_id": channel_id}) + except Exception as e: + logger.error(f"Error in is_channel_banned for channel {channel_id}: {e}", exc_info=True) + return None + + async def save_main_token(self, user_id: int, token_value: str, expires_at: datetime.datetime, created_at: datetime.datetime, activated: bool) -> None: + try: + await self.token_col.update_one( + {"user_id": user_id, "token": token_value}, + {"$set": { + "expires_at": expires_at, + "created_at": created_at, + "activated": activated + } + }, + upsert=True + ) + logger.debug(f"Saved main token {token_value} for user {user_id} with activated status {activated}.") + except Exception as e: + logger.error(f"Error saving main token for user {user_id}: {e}", exc_info=True) + raise + + + async def add_restart_message(self, message_id: int, chat_id: int) -> None: + try: + await self.restart_message_col.insert_one({ + "message_id": message_id, + "chat_id": chat_id, + "timestamp": datetime.datetime.utcnow() + }) + logger.debug(f"Added restart message {message_id} for chat {chat_id}.") + except Exception as e: + logger.error(f"Error adding restart message {message_id}: {e}", exc_info=True) + + async def get_restart_message(self) -> Optional[Dict[str, Any]]: + try: + return await self.restart_message_col.find_one(sort=[("timestamp", -1)]) + except Exception as e: + logger.error(f"Error getting restart message: {e}", exc_info=True) + return None + + async def delete_restart_message(self, message_id: int) -> None: + try: + await self.restart_message_col.delete_one({"message_id": message_id}) + logger.debug(f"Deleted restart message {message_id}.") + except Exception as e: + logger.error(f"Error deleting restart message {message_id}: {e}", exc_info=True) + + async def close(self): + if self._client: + await self._client.close() + +db = Database(Var.DATABASE_URL, Var.NAME)