"""Async SQLAlchemy engine, session factory, declarative base, and DB helpers."""

from typing import AsyncGenerator

from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
from sqlalchemy.orm import declarative_base

from app.core.config import settings

# Shared async engine. SQL echo follows ``settings.DEBUG`` when the settings
# object defines it, and is off otherwise.
engine = create_async_engine(
    settings.DATABASE_URL,
    echo=getattr(settings, "DEBUG", False),
)

# Session factory bound to ``engine``. ``expire_on_commit=False`` keeps loaded
# attribute values on ORM objects after commit. Otherwise accessing them would
# require an implicit refresh, which is not allowed under asyncio.
async_session_factory = async_sessionmaker(
    engine, class_=AsyncSession, expire_on_commit=False
)

# Declarative base class for ORM models.
Base = declarative_base()


async def get_db() -> AsyncGenerator[AsyncSession, None]:
    """Yield an ``AsyncSession`` scoped to a single unit of work.

    When the consumer finishes without error, the session is committed.
    If an exception propagates back into the generator (or commit itself
    fails), the session is rolled back and the exception is re-raised.
    The session is closed by the ``async with`` block in every case.

    Presumably used as a request-scoped dependency (e.g. FastAPI
    ``Depends(get_db)``); confirm against callers.
    """
    async with async_session_factory() as session:
        try:
            yield session
            await session.commit()
        except Exception:
            await session.rollback()
            raise


async def init_db() -> None:
    """Create every table registered on ``Base.metadata`` that does not exist yet.

    Existing tables are left as-is: ``create_all`` does not alter or migrate
    their schema.
    """
    # The model modules are imported only for their side effect. Presumably
    # defining their classes registers the tables on Base.metadata. The
    # imports are kept function-local, likely to avoid an import cycle with
    # modules that import Base from here.
    import app.models.user  # noqa: F401
    import app.models.gpu_cluster  # noqa: F401
    import app.models.task  # noqa: F401
    import app.models.datasource  # noqa: F401

    # engine.begin() opens a connection in a transaction that commits on exit.
    # run_sync runs the synchronous create_all against that connection.
    async with engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)