77 lines
2.4 KiB
Python
77 lines
2.4 KiB
Python
from typing import AsyncGenerator
|
|
|
|
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
|
|
from sqlalchemy.orm import declarative_base
|
|
|
|
from app.core.config import settings
|
|
|
|
# Shared async engine. SQL echo follows settings.DEBUG when the settings
# object defines it, and is off otherwise.
engine = create_async_engine(
    settings.DATABASE_URL,
    echo=getattr(settings, "DEBUG", False),
)

# Factory producing AsyncSession objects bound to the engine above.
# expire_on_commit=False keeps already-loaded attributes on ORM instances
# after commit instead of expiring them.
async_session_factory = async_sessionmaker(
    engine,
    class_=AsyncSession,
    expire_on_commit=False,
)

# Declarative base class for the ORM models.
Base = declarative_base()
|
|
|
|
|
|
async def get_db() -> AsyncGenerator[AsyncSession, None]:
    """Yield an ``AsyncSession`` and finalize it when the consumer is done.

    After the consumer resumes the generator normally, the session is
    committed. If an exception is raised into the generator, or the commit
    itself fails, the session is rolled back and the exception re-raised.
    The session is closed by the ``async with`` block in every case.
    """
    async with async_session_factory() as db:
        try:
            yield db
            # The commit sits inside the try so that a failing commit is
            # also rolled back explicitly below.
            await db.commit()
        except Exception:
            await db.rollback()
            raise
|
|
|
|
|
|
async def seed_default_datasources(session: AsyncSession):
    """Insert or refresh one ``DataSource`` row per ``DEFAULT_DATASOURCES`` entry.

    Each entry is looked up by ``info["id"]``:

    * If a row exists, ``name``, ``source``, ``module``, ``priority``,
      ``frequency_minutes`` and ``collector_class`` are overwritten from the
      defaults. ``config`` is set to ``"{}"`` only when it is ``None``, and
      ``is_active`` is not touched.
    * Otherwise a new row is added with the same fields, ``config="{}"`` and
      ``is_active=True``.

    The session is committed once, after all entries are processed.

    Args:
        session: Session used for lookups, inserts and the final commit.
    """
    # Imported lazily, presumably to avoid import cycles at module load
    # time. TODO confirm.
    from app.core.datasource_defaults import DEFAULT_DATASOURCES
    from app.models.datasource import DataSource

    for source, info in DEFAULT_DATASOURCES.items():
        record = await session.get(DataSource, info["id"])

        # Fields applied on both the update and insert paths. Insertion order
        # matches the original assignment/keyword order.
        synced_fields = {
            "name": info["name"],
            "source": source,
            "module": info["module"],
            "priority": info["priority"],
            "frequency_minutes": info["frequency_minutes"],
            "collector_class": source,
        }

        if not record:
            session.add(
                DataSource(
                    id=info["id"],
                    **synced_fields,
                    config="{}",
                    is_active=True,
                )
            )
            continue

        for attr_name, attr_value in synced_fields.items():
            setattr(record, attr_name, attr_value)
        # Only backfill a missing config; never clobber an existing one.
        if record.config is None:
            record.config = "{}"

    await session.commit()
|
|
|
|
|
|
async def init_db():
    """Create all tables known to ``Base.metadata`` and seed default data sources.

    ``create_all`` runs with its default ``checkfirst=True``, so tables that
    already exist are skipped. Seeding runs afterwards in its own session.
    """
    # These modules are imported only for their side effects (presumably
    # registering their models on Base.metadata) and must be loaded before
    # create_all runs. The bound names are unused, hence the noqa markers.
    import app.models.user  # noqa: F401
    import app.models.gpu_cluster  # noqa: F401
    import app.models.task  # noqa: F401
    import app.models.datasource  # noqa: F401
    import app.models.datasource_config  # noqa: F401
    import app.models.alert  # noqa: F401
    import app.models.collected_data  # noqa: F401
    import app.models.system_setting  # noqa: F401

    # DDL runs inside a transaction opened by engine.begin(). run_sync
    # bridges the synchronous MetaData.create_all onto the async connection.
    async with engine.begin() as ddl_connection:
        await ddl_connection.run_sync(Base.metadata.create_all)

    async with async_session_factory() as seed_session:
        await seed_default_datasources(seed_session)
|