first commit
This commit is contained in:
162
backend/app/core/security.py
Normal file
162
backend/app/core/security.py
Normal file
@@ -0,0 +1,162 @@
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional
|
||||
|
||||
import bcrypt
|
||||
import redis
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||
from jose import JWTError, jwt
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.config import settings
|
||||
from app.db.session import get_db
|
||||
from app.models.user import User
|
||||
|
||||
oauth2_scheme = HTTPBearer()
|
||||
|
||||
|
||||
class _RedisClient:
|
||||
_client = None
|
||||
|
||||
@classmethod
|
||||
def get_client(cls):
|
||||
if cls._client is None:
|
||||
redis_url = settings.REDIS_URL
|
||||
if redis_url.startswith("redis://"):
|
||||
cls._client = redis.from_url(redis_url, decode_responses=True)
|
||||
else:
|
||||
cls._client = redis.Redis(
|
||||
host=settings.REDIS_SERVER,
|
||||
port=settings.REDIS_PORT,
|
||||
db=settings.REDIS_DB,
|
||||
decode_responses=True,
|
||||
)
|
||||
return cls._client
|
||||
|
||||
|
||||
redis_client = _RedisClient.get_client()
|
||||
|
||||
|
||||
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
||||
return bcrypt.checkpw(plain_password.encode(), hashed_password.encode())
|
||||
|
||||
|
||||
def get_password_hash(password: str) -> str:
|
||||
return bcrypt.hashpw(password.encode(), bcrypt.gensalt()).decode()
|
||||
|
||||
|
||||
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
|
||||
to_encode = data.copy()
|
||||
if expires_delta:
|
||||
expire = datetime.utcnow() + expires_delta
|
||||
else:
|
||||
expire = datetime.utcnow() + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||
to_encode.update({"exp": expire, "type": "access"})
|
||||
if "sub" in to_encode:
|
||||
to_encode["sub"] = str(to_encode["sub"])
|
||||
return jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
|
||||
|
||||
|
||||
def create_refresh_token(data: dict) -> str:
|
||||
to_encode = data.copy()
|
||||
expire = datetime.utcnow() + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS)
|
||||
to_encode.update({"exp": expire, "type": "refresh"})
|
||||
if "sub" in to_encode:
|
||||
to_encode["sub"] = str(to_encode["sub"])
|
||||
return jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
|
||||
|
||||
|
||||
def decode_token(token: str) -> Optional[dict]:
|
||||
try:
|
||||
payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
|
||||
return payload
|
||||
except JWTError:
|
||||
return None
|
||||
|
||||
|
||||
async def get_current_user(
|
||||
credentials: HTTPAuthorizationCredentials = Depends(oauth2_scheme),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> User:
|
||||
token = credentials.credentials
|
||||
if redis_client.sismember("blacklisted_tokens", token):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Token has been revoked",
|
||||
)
|
||||
payload = decode_token(token)
|
||||
if payload is None or payload.get("type") != "access":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid token",
|
||||
)
|
||||
user_id = payload.get("sub")
|
||||
if user_id is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid token",
|
||||
)
|
||||
result = await db.execute(
|
||||
text(
|
||||
"SELECT id, username, email, password_hash, role, is_active FROM users WHERE id = :id"
|
||||
),
|
||||
{"id": int(user_id)},
|
||||
)
|
||||
row = result.fetchone()
|
||||
if row is None or not row[5]:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="User not found or inactive",
|
||||
)
|
||||
user = User()
|
||||
user.id = row[0]
|
||||
user.username = row[1]
|
||||
user.email = row[2]
|
||||
user.password_hash = row[3]
|
||||
user.role = row[4]
|
||||
user.is_active = row[5]
|
||||
return user
|
||||
|
||||
|
||||
async def get_current_user_refresh(
|
||||
credentials: HTTPAuthorizationCredentials = Depends(oauth2_scheme),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> User:
|
||||
token = credentials.credentials
|
||||
payload = decode_token(token)
|
||||
if payload is None or payload.get("type") != "refresh":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid refresh token",
|
||||
)
|
||||
user_id = payload.get("sub")
|
||||
if user_id is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid token",
|
||||
)
|
||||
result = await db.execute(
|
||||
text(
|
||||
"SELECT id, username, email, password_hash, role, is_active FROM users WHERE id = :id"
|
||||
),
|
||||
{"id": int(user_id)},
|
||||
)
|
||||
row = result.fetchone()
|
||||
if row is None or not row[5]:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="User not found or inactive",
|
||||
)
|
||||
user = User()
|
||||
user.id = row[0]
|
||||
user.username = row[1]
|
||||
user.email = row[2]
|
||||
user.password_hash = row[3]
|
||||
user.role = row[4]
|
||||
user.is_active = row[5]
|
||||
return user
|
||||
|
||||
|
||||
def blacklist_token(token: str) -> None:
|
||||
redis_client.sadd("blacklisted_tokens", token)
|
||||
Reference in New Issue
Block a user