"""WebSocket Connection Manager""" import json import asyncio from typing import Dict, Set, Optional from datetime import datetime from fastapi import WebSocket import redis.asyncio as redis from app.core.config import settings class ConnectionManager: """Manages WebSocket connections""" def __init__(self): self.active_connections: Dict[str, Set[WebSocket]] = {} # user_id -> connections self.redis_client: Optional[redis.Redis] = None async def connect(self, websocket: WebSocket, user_id: str): await websocket.accept() if user_id not in self.active_connections: self.active_connections[user_id] = set() self.active_connections[user_id].add(websocket) if self.redis_client is None: redis_url = settings.REDIS_URL if redis_url.startswith("redis://"): self.redis_client = redis.from_url(redis_url, decode_responses=True) else: self.redis_client = redis.Redis( host=settings.REDIS_SERVER, port=settings.REDIS_PORT, db=settings.REDIS_DB, decode_responses=True, ) def disconnect(self, websocket: WebSocket, user_id: str): if user_id in self.active_connections: self.active_connections[user_id].discard(websocket) if not self.active_connections[user_id]: del self.active_connections[user_id] async def send_personal_message(self, message: dict, user_id: str): if user_id in self.active_connections: for connection in self.active_connections[user_id]: try: await connection.send_json(message) except Exception: pass async def broadcast(self, message: dict, channel: str = "all"): if channel == "all": for user_id in self.active_connections: await self.send_personal_message(message, user_id) else: await self.send_personal_message(message, channel) async def close_all(self): for user_id in self.active_connections: for connection in self.active_connections[user_id]: await connection.close() self.active_connections.clear() manager = ConnectionManager() async def get_websocket_manager() -> ConnectionManager: return manager