"""Authentication helpers.

Password hashing, JWT issuance/validation, token revocation, and FastAPI
dependencies that resolve the current user from a bearer token.
"""

from datetime import datetime, timedelta, timezone
from typing import Optional

import bcrypt
import redis
from fastapi import Depends, HTTPException, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from jose import JWTError, jwt
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession

from app.core.config import settings
from app.db.session import get_db
from app.models.user import User

oauth2_scheme = HTTPBearer()

# Redis set holding the raw strings of revoked tokens.
_BLACKLIST_KEY = "blacklisted_tokens"

# URL schemes accepted by ``redis.from_url``.
_REDIS_URL_SCHEMES = ("redis://", "rediss://", "unix://")

# Column order matters: rows are read positionally in ``_load_active_user``.
_USER_QUERY = text(
    "SELECT id, username, email, password_hash, role, is_active "
    "FROM users WHERE id = :id"
)


class _RedisClient:
    """Holder for a single, lazily-created Redis client shared by the process."""

    _client = None

    @classmethod
    def get_client(cls):
        """Return the shared Redis client, creating it on first call.

        ``settings.REDIS_URL`` is used when it carries a scheme understood by
        ``redis.from_url`` (plain, TLS ``rediss://`` or ``unix://`` socket).
        Otherwise the discrete ``REDIS_SERVER`` / ``REDIS_PORT`` / ``REDIS_DB``
        settings are used.
        """
        if cls._client is None:
            # str() also accepts URL objects; a missing/None URL falls back
            # to the host/port settings instead of raising AttributeError.
            redis_url = str(settings.REDIS_URL or "")
            if redis_url.startswith(_REDIS_URL_SCHEMES):
                cls._client = redis.from_url(redis_url, decode_responses=True)
            else:
                cls._client = redis.Redis(
                    host=settings.REDIS_SERVER,
                    port=settings.REDIS_PORT,
                    db=settings.REDIS_DB,
                    decode_responses=True,
                )
        return cls._client


redis_client = _RedisClient.get_client()


def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Return True if *plain_password* matches the bcrypt *hashed_password*."""
    return bcrypt.checkpw(plain_password.encode(), hashed_password.encode())


def get_password_hash(password: str) -> str:
    """Hash *password* with bcrypt using a freshly generated salt."""
    return bcrypt.hashpw(password.encode(), bcrypt.gensalt()).decode()


def _utcnow() -> datetime:
    """Return the current time as a timezone-aware UTC datetime."""
    return datetime.now(timezone.utc)


def _encode_token(data: dict, token_type: str, expire: Optional[datetime]) -> str:
    """Sign a copy of *data* as a JWT tagged with *token_type*.

    ``exp`` is only set when *expire* is given. ``sub`` is coerced to a string,
    because JWT requires the subject claim to be a string.
    """
    to_encode = data.copy()
    if expire is not None:
        to_encode["exp"] = expire
    to_encode["type"] = token_type
    if "sub" in to_encode:
        to_encode["sub"] = str(to_encode["sub"])
    return jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)


def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
    """Create a signed access token carrying *data*.

    Args:
        data: Claims to embed (typically including ``sub``).
        expires_delta: Lifetime override. When falsy, ``ACCESS_TOKEN_EXPIRE_MINUTES``
            is used. If that setting is not positive, the token has no ``exp`` claim.

    Returns:
        The encoded JWT string.
    """
    if expires_delta:
        expire: Optional[datetime] = _utcnow() + expires_delta
    elif settings.ACCESS_TOKEN_EXPIRE_MINUTES > 0:
        expire = _utcnow() + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
    else:
        expire = None
    return _encode_token(data, "access", expire)


def create_refresh_token(data: dict) -> str:
    """Create a signed refresh token carrying *data*.

    The token expires after ``REFRESH_TOKEN_EXPIRE_DAYS``. If that setting is
    not positive, it has no ``exp`` claim.
    """
    expire: Optional[datetime] = None
    if settings.REFRESH_TOKEN_EXPIRE_DAYS > 0:
        expire = _utcnow() + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS)
    return _encode_token(data, "refresh", expire)


def decode_token(token: str) -> Optional[dict]:
    """Verify and decode *token*; return its claims, or None if invalid/expired."""
    try:
        return jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
    except JWTError:
        return None


def _unauthorized(detail: str) -> HTTPException:
    """Build a 401 HTTPException with the given *detail* message."""
    return HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=detail)


def _is_revoked(token: str) -> bool:
    """Return True if *token* has been added to the revocation set."""
    # NOTE(review): this is a synchronous Redis round-trip made from async
    # dependencies, so it blocks the event loop while it waits. Consider
    # redis.asyncio if that latency matters.
    return bool(redis_client.sismember(_BLACKLIST_KEY, token))


def _require_user_id(token: str, expected_type: str, invalid_detail: str) -> int:
    """Decode *token*, check its ``type`` claim and return the integer ``sub``.

    Raises:
        HTTPException: 401 with *invalid_detail* if the token is undecodable or
            of the wrong type. 401 "Invalid token" if ``sub`` is missing or
            not an integer.
    """
    payload = decode_token(token)
    if payload is None or payload.get("type") != expected_type:
        raise _unauthorized(invalid_detail)
    sub = payload.get("sub")
    if sub is None:
        raise _unauthorized("Invalid token")
    try:
        return int(sub)
    except (TypeError, ValueError):
        # A non-numeric subject is a malformed token, not a server error.
        raise _unauthorized("Invalid token") from None


async def _load_active_user(db: AsyncSession, user_id: int) -> User:
    """Fetch user *user_id* and return it as a ``User``.

    Raises:
        HTTPException: 401 if the user does not exist or is inactive.
    """
    result = await db.execute(_USER_QUERY, {"id": user_id})
    row = result.fetchone()
    if row is None or not row[5]:
        raise _unauthorized("User not found or inactive")
    user = User()
    user.id = row[0]
    user.username = row[1]
    user.email = row[2]
    user.password_hash = row[3]
    user.role = row[4]
    user.is_active = row[5]
    return user


async def get_current_user(
    credentials: HTTPAuthorizationCredentials = Depends(oauth2_scheme),
    db: AsyncSession = Depends(get_db),
) -> User:
    """FastAPI dependency: resolve the active user from a bearer access token.

    Raises:
        HTTPException: 401 if the token is revoked, invalid, not an access
            token, or refers to a missing/inactive user.
    """
    token = credentials.credentials
    if _is_revoked(token):
        raise _unauthorized("Token has been revoked")
    user_id = _require_user_id(token, "access", "Invalid token")
    return await _load_active_user(db, user_id)


async def get_current_user_refresh(
    credentials: HTTPAuthorizationCredentials = Depends(oauth2_scheme),
    db: AsyncSession = Depends(get_db),
) -> User:
    """FastAPI dependency: resolve the active user from a bearer refresh token.

    Raises:
        HTTPException: 401 if the token is revoked, invalid, not a refresh
            token, or refers to a missing/inactive user.
    """
    token = credentials.credentials
    # Revoked refresh tokens must not be able to mint new access tokens.
    if _is_revoked(token):
        raise _unauthorized("Token has been revoked")
    user_id = _require_user_id(token, "refresh", "Invalid refresh token")
    return await _load_active_user(db, user_id)


def blacklist_token(token: str) -> None:
    """Revoke *token* by adding it to the Redis revocation set."""
    # NOTE(review): entries are never removed, so the set grows without bound.
    # Consider a per-token key with a TTL equal to the token's remaining lifetime.
    redis_client.sadd(_BLACKLIST_KEY, token)