first commit
This commit is contained in:
23
backend/.env.example
Normal file
23
backend/.env.example
Normal file
@@ -0,0 +1,23 @@
|
||||
# Database
|
||||
POSTGRES_SERVER=localhost
|
||||
POSTGRES_USER=postgres
|
||||
POSTGRES_PASSWORD=postgres
|
||||
POSTGRES_DB=planet_db
|
||||
|
||||
# Redis
|
||||
REDIS_SERVER=localhost
|
||||
REDIS_PORT=6379
|
||||
|
||||
# Security
|
||||
SECRET_KEY=your-secret-key-change-in-production
|
||||
ALGORITHM=HS256
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES=15
|
||||
REFRESH_TOKEN_EXPIRE_DAYS=7
|
||||
|
||||
# API
|
||||
API_V1_STR=/api/v1
|
||||
PROJECT_NAME="Intelligent Planet Plan"
|
||||
VERSION=1.0.0
|
||||
|
||||
# CORS
|
||||
CORS_ORIGINS=["http://localhost:3000", "http://localhost:8000"]
|
||||
19
backend/Dockerfile
Normal file
19
backend/Dockerfile
Normal file
@@ -0,0 +1,19 @@
|
||||
FROM python:3.11-slim
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
ENV PYTHONDONTWRITEBYTECODE=1
|
||||
ENV PYTHONUNBUFFERED=1
|
||||
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
curl \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
COPY requirements.txt .
|
||||
RUN pip install --no-cache-dir -r requirements.txt
|
||||
|
||||
COPY . .
|
||||
|
||||
EXPOSE 8000
|
||||
|
||||
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]
|
||||
0
backend/app/__init__.py
Normal file
0
backend/app/__init__.py
Normal file
BIN
backend/app/__pycache__/__init__.cpython-311.pyc
Normal file
BIN
backend/app/__pycache__/__init__.cpython-311.pyc
Normal file
Binary file not shown.
BIN
backend/app/__pycache__/main.cpython-311.pyc
Normal file
BIN
backend/app/__pycache__/main.cpython-311.pyc
Normal file
Binary file not shown.
BIN
backend/app/api/__pycache__/main.cpython-311.pyc
Normal file
BIN
backend/app/api/__pycache__/main.cpython-311.pyc
Normal file
Binary file not shown.
27
backend/app/api/main.py
Normal file
27
backend/app/api/main.py
Normal file
@@ -0,0 +1,27 @@
|
||||
from fastapi import APIRouter
|
||||
from app.api.v1 import (
|
||||
auth,
|
||||
users,
|
||||
datasource_config,
|
||||
datasources,
|
||||
tasks,
|
||||
dashboard,
|
||||
websocket,
|
||||
alerts,
|
||||
settings,
|
||||
collected_data,
|
||||
)
|
||||
|
||||
api_router = APIRouter()
|
||||
|
||||
api_router.include_router(auth.router, prefix="/auth", tags=["auth"])
|
||||
api_router.include_router(users.router, prefix="/users", tags=["users"])
|
||||
api_router.include_router(
|
||||
datasource_config.router, prefix="/datasources", tags=["datasource-config"]
|
||||
)
|
||||
api_router.include_router(datasources.router, prefix="/datasources", tags=["datasources"])
|
||||
api_router.include_router(collected_data.router, prefix="/collected", tags=["collected-data"])
|
||||
api_router.include_router(tasks.router, prefix="/tasks", tags=["tasks"])
|
||||
api_router.include_router(dashboard.router, prefix="/dashboard", tags=["dashboard"])
|
||||
api_router.include_router(alerts.router, prefix="/alerts", tags=["alerts"])
|
||||
api_router.include_router(settings.router, prefix="/settings", tags=["settings"])
|
||||
BIN
backend/app/api/v1/__pycache__/alerts.cpython-311.pyc
Normal file
BIN
backend/app/api/v1/__pycache__/alerts.cpython-311.pyc
Normal file
Binary file not shown.
BIN
backend/app/api/v1/__pycache__/auth.cpython-311.pyc
Normal file
BIN
backend/app/api/v1/__pycache__/auth.cpython-311.pyc
Normal file
Binary file not shown.
BIN
backend/app/api/v1/__pycache__/dashboard.cpython-311.pyc
Normal file
BIN
backend/app/api/v1/__pycache__/dashboard.cpython-311.pyc
Normal file
Binary file not shown.
BIN
backend/app/api/v1/__pycache__/datasource_config.cpython-311.pyc
Normal file
BIN
backend/app/api/v1/__pycache__/datasource_config.cpython-311.pyc
Normal file
Binary file not shown.
BIN
backend/app/api/v1/__pycache__/datasources.cpython-311.pyc
Normal file
BIN
backend/app/api/v1/__pycache__/datasources.cpython-311.pyc
Normal file
Binary file not shown.
BIN
backend/app/api/v1/__pycache__/settings.cpython-311.pyc
Normal file
BIN
backend/app/api/v1/__pycache__/settings.cpython-311.pyc
Normal file
Binary file not shown.
BIN
backend/app/api/v1/__pycache__/tasks.cpython-311.pyc
Normal file
BIN
backend/app/api/v1/__pycache__/tasks.cpython-311.pyc
Normal file
Binary file not shown.
BIN
backend/app/api/v1/__pycache__/users.cpython-311.pyc
Normal file
BIN
backend/app/api/v1/__pycache__/users.cpython-311.pyc
Normal file
Binary file not shown.
BIN
backend/app/api/v1/__pycache__/websocket.cpython-311.pyc
Normal file
BIN
backend/app/api/v1/__pycache__/websocket.cpython-311.pyc
Normal file
Binary file not shown.
124
backend/app/api/v1/alerts.py
Normal file
124
backend/app/api/v1/alerts.py
Normal file
@@ -0,0 +1,124 @@
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy import select, func, case
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db.session import get_db
|
||||
from app.models.user import User
|
||||
from app.core.security import get_current_user
|
||||
from app.models.alert import Alert, AlertSeverity, AlertStatus
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("")
|
||||
async def list_alerts(
|
||||
severity: str = None,
|
||||
status: str = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
query = select(Alert)
|
||||
|
||||
if severity:
|
||||
query = query.where(Alert.severity == AlertSeverity(severity))
|
||||
if status:
|
||||
query = query.where(Alert.status == AlertStatus(status))
|
||||
|
||||
query = query.order_by(
|
||||
case(
|
||||
(Alert.severity == AlertSeverity.CRITICAL, 1),
|
||||
(Alert.severity == AlertSeverity.WARNING, 2),
|
||||
(Alert.severity == AlertSeverity.INFO, 3),
|
||||
),
|
||||
Alert.created_at.desc(),
|
||||
)
|
||||
|
||||
result = await db.execute(query)
|
||||
alerts = result.scalars().all()
|
||||
|
||||
total_query = select(func.count(Alert.id))
|
||||
if severity:
|
||||
total_query = total_query.where(Alert.severity == AlertSeverity(severity))
|
||||
if status:
|
||||
total_query = total_query.where(Alert.status == AlertStatus(status))
|
||||
total_result = await db.execute(total_query)
|
||||
total = total_result.scalar()
|
||||
|
||||
return {
|
||||
"total": total,
|
||||
"data": [alert.to_dict() for alert in alerts],
|
||||
}
|
||||
|
||||
|
||||
@router.post("/{alert_id}/acknowledge")
|
||||
async def acknowledge_alert(
|
||||
alert_id: int,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
result = await db.execute(select(Alert).where(Alert.id == alert_id))
|
||||
alert = result.scalar_one_or_none()
|
||||
|
||||
if not alert:
|
||||
return {"error": "Alert not found"}
|
||||
|
||||
alert.status = AlertStatus.ACKNOWLEDGED
|
||||
alert.acknowledged_by = current_user.id
|
||||
alert.acknowledged_at = datetime.utcnow()
|
||||
await db.commit()
|
||||
|
||||
return {"message": "Alert acknowledged", "alert": alert.to_dict()}
|
||||
|
||||
|
||||
@router.post("/{alert_id}/resolve")
|
||||
async def resolve_alert(
|
||||
alert_id: int,
|
||||
resolution: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
result = await db.execute(select(Alert).where(Alert.id == alert_id))
|
||||
alert = result.scalar_one_or_none()
|
||||
|
||||
if not alert:
|
||||
return {"error": "Alert not found"}
|
||||
|
||||
alert.status = AlertStatus.RESOLVED
|
||||
alert.resolved_by = current_user.id
|
||||
alert.resolved_at = datetime.utcnow()
|
||||
alert.resolution_notes = resolution
|
||||
await db.commit()
|
||||
|
||||
return {"message": "Alert resolved", "alert": alert.to_dict()}
|
||||
|
||||
|
||||
@router.get("/stats")
|
||||
async def get_alert_stats(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
critical_query = select(func.count(Alert.id)).where(
|
||||
Alert.severity == AlertSeverity.CRITICAL,
|
||||
Alert.status == AlertStatus.ACTIVE,
|
||||
)
|
||||
warning_query = select(func.count(Alert.id)).where(
|
||||
Alert.severity == AlertSeverity.WARNING,
|
||||
Alert.status == AlertStatus.ACTIVE,
|
||||
)
|
||||
info_query = select(func.count(Alert.id)).where(
|
||||
Alert.severity == AlertSeverity.INFO,
|
||||
Alert.status == AlertStatus.ACTIVE,
|
||||
)
|
||||
|
||||
critical_result = await db.execute(critical_query)
|
||||
warning_result = await db.execute(warning_query)
|
||||
info_result = await db.execute(info_query)
|
||||
|
||||
return {
|
||||
"critical": critical_result.scalar() or 0,
|
||||
"warning": warning_result.scalar() or 0,
|
||||
"info": info_result.scalar() or 0,
|
||||
}
|
||||
108
backend/app/api/v1/auth.py
Normal file
108
backend/app/api/v1/auth.py
Normal file
@@ -0,0 +1,108 @@
|
||||
from datetime import timedelta
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import text
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.security import (
|
||||
create_access_token,
|
||||
create_refresh_token,
|
||||
blacklist_token,
|
||||
get_current_user,
|
||||
verify_password,
|
||||
)
|
||||
from app.db.session import get_db
|
||||
from app.models.user import User
|
||||
from app.schemas.token import Token
|
||||
from app.schemas.user import UserCreate, UserResponse
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post("/login", response_model=Token)
|
||||
async def login(
|
||||
form_data: OAuth2PasswordRequestForm = Depends(),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
result = await db.execute(
|
||||
text(
|
||||
"SELECT id, username, email, password_hash, role, is_active FROM users WHERE username = :username"
|
||||
),
|
||||
{"username": form_data.username},
|
||||
)
|
||||
row = result.fetchone()
|
||||
if row is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid credentials",
|
||||
)
|
||||
|
||||
user = User()
|
||||
user.id = row[0]
|
||||
user.username = row[1]
|
||||
user.email = row[2]
|
||||
user.password_hash = row[3]
|
||||
user.role = row[4]
|
||||
user.is_active = row[5]
|
||||
|
||||
if not verify_password(form_data.password, user.password_hash):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid credentials",
|
||||
)
|
||||
if not user.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="User is inactive",
|
||||
)
|
||||
|
||||
access_token = create_access_token(data={"sub": user.id})
|
||||
refresh_token = create_refresh_token(data={"sub": user.id})
|
||||
|
||||
return {
|
||||
"access_token": access_token,
|
||||
"token_type": "bearer",
|
||||
"expires_in": settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60,
|
||||
"user": {
|
||||
"id": user.id,
|
||||
"username": user.username,
|
||||
"role": user.role,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@router.post("/refresh", response_model=Token)
|
||||
async def refresh_token(
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
access_token = create_access_token(data={"sub": current_user.id})
|
||||
|
||||
return {
|
||||
"access_token": access_token,
|
||||
"token_type": "bearer",
|
||||
"expires_in": settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60,
|
||||
"user": {
|
||||
"id": current_user.id,
|
||||
"username": current_user.username,
|
||||
"role": current_user.role,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@router.post("/logout")
|
||||
async def logout():
|
||||
return {"message": "Successfully logged out"}
|
||||
|
||||
|
||||
@router.get("/me", response_model=UserResponse)
|
||||
async def get_me(current_user: User = Depends(get_current_user)):
|
||||
return {
|
||||
"id": current_user.id,
|
||||
"username": current_user.username,
|
||||
"email": current_user.email,
|
||||
"role": current_user.role,
|
||||
"is_active": current_user.is_active,
|
||||
"created_at": current_user.created_at,
|
||||
}
|
||||
431
backend/app/api/v1/collected_data.py
Normal file
431
backend/app/api/v1/collected_data.py
Normal file
@@ -0,0 +1,431 @@
|
||||
from typing import Optional
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status, Response
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy import select, func, text
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
import json
|
||||
import csv
|
||||
import io
|
||||
|
||||
from app.db.session import get_db
|
||||
from app.models.user import User
|
||||
from app.core.security import get_current_user
|
||||
from app.models.collected_data import CollectedData
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("")
|
||||
async def list_collected_data(
|
||||
source: Optional[str] = Query(None, description="数据源过滤"),
|
||||
data_type: Optional[str] = Query(None, description="数据类型过滤"),
|
||||
country: Optional[str] = Query(None, description="国家过滤"),
|
||||
search: Optional[str] = Query(None, description="搜索名称"),
|
||||
page: int = Query(1, ge=1, description="页码"),
|
||||
page_size: int = Query(20, ge=1, le=100, description="每页数量"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""查询采集的数据列表"""
|
||||
|
||||
# Build WHERE clause
|
||||
conditions = []
|
||||
params = {}
|
||||
|
||||
if source:
|
||||
conditions.append("source = :source")
|
||||
params["source"] = source
|
||||
if data_type:
|
||||
conditions.append("data_type = :data_type")
|
||||
params["data_type"] = data_type
|
||||
if country:
|
||||
conditions.append("country = :country")
|
||||
params["country"] = country
|
||||
if search:
|
||||
conditions.append("(name ILIKE :search OR title ILIKE :search)")
|
||||
params["search"] = f"%{search}%"
|
||||
|
||||
where_sql = " AND ".join(conditions) if conditions else "1=1"
|
||||
|
||||
# Calculate offset
|
||||
offset = (page - 1) * page_size
|
||||
|
||||
# Query total count
|
||||
count_query = text(f"SELECT COUNT(*) FROM collected_data WHERE {where_sql}")
|
||||
count_result = await db.execute(count_query, params)
|
||||
total = count_result.scalar()
|
||||
|
||||
# Query data
|
||||
query = text(f"""
|
||||
SELECT id, source, source_id, data_type, name, title, description,
|
||||
country, city, latitude, longitude, value, unit,
|
||||
metadata, collected_at, reference_date, is_valid
|
||||
FROM collected_data
|
||||
WHERE {where_sql}
|
||||
ORDER BY collected_at DESC
|
||||
LIMIT :limit OFFSET :offset
|
||||
""")
|
||||
params["limit"] = page_size
|
||||
params["offset"] = offset
|
||||
|
||||
result = await db.execute(query, params)
|
||||
rows = result.fetchall()
|
||||
|
||||
data = []
|
||||
for row in rows:
|
||||
data.append(
|
||||
{
|
||||
"id": row[0],
|
||||
"source": row[1],
|
||||
"source_id": row[2],
|
||||
"data_type": row[3],
|
||||
"name": row[4],
|
||||
"title": row[5],
|
||||
"description": row[6],
|
||||
"country": row[7],
|
||||
"city": row[8],
|
||||
"latitude": row[9],
|
||||
"longitude": row[10],
|
||||
"value": row[11],
|
||||
"unit": row[12],
|
||||
"metadata": row[13],
|
||||
"collected_at": row[14].isoformat() if row[14] else None,
|
||||
"reference_date": row[15].isoformat() if row[15] else None,
|
||||
"is_valid": row[16],
|
||||
}
|
||||
)
|
||||
|
||||
return {
|
||||
"total": total,
|
||||
"page": page,
|
||||
"page_size": page_size,
|
||||
"data": data,
|
||||
}
|
||||
|
||||
|
||||
@router.get("/summary")
|
||||
async def get_data_summary(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""获取数据汇总统计"""
|
||||
|
||||
# By source and data_type
|
||||
result = await db.execute(
|
||||
text("""
|
||||
SELECT source, data_type, COUNT(*) as count
|
||||
FROM collected_data
|
||||
GROUP BY source, data_type
|
||||
ORDER BY source, data_type
|
||||
""")
|
||||
)
|
||||
rows = result.fetchall()
|
||||
|
||||
by_source = {}
|
||||
total = 0
|
||||
for row in rows:
|
||||
source = row[0]
|
||||
data_type = row[1]
|
||||
count = row[2]
|
||||
|
||||
if source not in by_source:
|
||||
by_source[source] = {}
|
||||
by_source[source][data_type] = count
|
||||
total += count
|
||||
|
||||
# Total by source
|
||||
source_totals = await db.execute(
|
||||
text("""
|
||||
SELECT source, COUNT(*) as count
|
||||
FROM collected_data
|
||||
GROUP BY source
|
||||
ORDER BY count DESC
|
||||
""")
|
||||
)
|
||||
source_rows = source_totals.fetchall()
|
||||
|
||||
return {
|
||||
"total_records": total,
|
||||
"by_source": by_source,
|
||||
"source_totals": [{"source": row[0], "count": row[1]} for row in source_rows],
|
||||
}
|
||||
|
||||
|
||||
@router.get("/sources")
|
||||
async def get_data_sources(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""获取所有数据源列表"""
|
||||
|
||||
result = await db.execute(
|
||||
text("""
|
||||
SELECT DISTINCT source FROM collected_data ORDER BY source
|
||||
""")
|
||||
)
|
||||
rows = result.fetchall()
|
||||
|
||||
return {
|
||||
"sources": [row[0] for row in rows],
|
||||
}
|
||||
|
||||
|
||||
@router.get("/types")
|
||||
async def get_data_types(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""获取所有数据类型列表"""
|
||||
|
||||
result = await db.execute(
|
||||
text("""
|
||||
SELECT DISTINCT data_type FROM collected_data ORDER BY data_type
|
||||
""")
|
||||
)
|
||||
rows = result.fetchall()
|
||||
|
||||
return {
|
||||
"data_types": [row[0] for row in rows],
|
||||
}
|
||||
|
||||
|
||||
@router.get("/countries")
|
||||
async def get_countries(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""获取所有国家列表"""
|
||||
|
||||
result = await db.execute(
|
||||
text("""
|
||||
SELECT DISTINCT country FROM collected_data
|
||||
WHERE country IS NOT NULL AND country != ''
|
||||
ORDER BY country
|
||||
""")
|
||||
)
|
||||
rows = result.fetchall()
|
||||
|
||||
return {
|
||||
"countries": [row[0] for row in rows],
|
||||
}
|
||||
|
||||
|
||||
@router.get("/{data_id}")
|
||||
async def get_collected_data(
|
||||
data_id: int,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""获取单条采集数据详情"""
|
||||
|
||||
result = await db.execute(
|
||||
text("""
|
||||
SELECT id, source, source_id, data_type, name, title, description,
|
||||
country, city, latitude, longitude, value, unit,
|
||||
metadata, collected_at, reference_date, is_valid
|
||||
FROM collected_data
|
||||
WHERE id = :id
|
||||
"""),
|
||||
{"id": data_id},
|
||||
)
|
||||
row = result.fetchone()
|
||||
|
||||
if not row:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="数据不存在",
|
||||
)
|
||||
|
||||
return {
|
||||
"id": row[0],
|
||||
"source": row[1],
|
||||
"source_id": row[2],
|
||||
"data_type": row[3],
|
||||
"name": row[4],
|
||||
"title": row[5],
|
||||
"description": row[6],
|
||||
"country": row[7],
|
||||
"city": row[8],
|
||||
"latitude": row[9],
|
||||
"longitude": row[10],
|
||||
"value": row[11],
|
||||
"unit": row[12],
|
||||
"metadata": row[13],
|
||||
"collected_at": row[14].isoformat() if row[14] else None,
|
||||
"reference_date": row[15].isoformat() if row[15] else None,
|
||||
"is_valid": row[16],
|
||||
}
|
||||
|
||||
|
||||
def build_where_clause(
|
||||
source: Optional[str], data_type: Optional[str], country: Optional[str], search: Optional[str]
|
||||
):
|
||||
"""Build WHERE clause and params for queries"""
|
||||
conditions = []
|
||||
params = {}
|
||||
|
||||
if source:
|
||||
conditions.append("source = :source")
|
||||
params["source"] = source
|
||||
if data_type:
|
||||
conditions.append("data_type = :data_type")
|
||||
params["data_type"] = data_type
|
||||
if country:
|
||||
conditions.append("country = :country")
|
||||
params["country"] = country
|
||||
if search:
|
||||
conditions.append("(name ILIKE :search OR title ILIKE :search)")
|
||||
params["search"] = f"%{search}%"
|
||||
|
||||
where_sql = " AND ".join(conditions) if conditions else "1=1"
|
||||
return where_sql, params
|
||||
|
||||
|
||||
@router.get("/export/json")
|
||||
async def export_json(
|
||||
source: Optional[str] = Query(None, description="数据源过滤"),
|
||||
data_type: Optional[str] = Query(None, description="数据类型过滤"),
|
||||
country: Optional[str] = Query(None, description="国家过滤"),
|
||||
search: Optional[str] = Query(None, description="搜索名称"),
|
||||
limit: int = Query(10000, ge=1, le=50000, description="最大导出数量"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""导出数据为 JSON 格式"""
|
||||
|
||||
where_sql, params = build_where_clause(source, data_type, country, search)
|
||||
params["limit"] = limit
|
||||
|
||||
query = text(f"""
|
||||
SELECT id, source, source_id, data_type, name, title, description,
|
||||
country, city, latitude, longitude, value, unit,
|
||||
metadata, collected_at, reference_date, is_valid
|
||||
FROM collected_data
|
||||
WHERE {where_sql}
|
||||
ORDER BY collected_at DESC
|
||||
LIMIT :limit
|
||||
""")
|
||||
|
||||
result = await db.execute(query, params)
|
||||
rows = result.fetchall()
|
||||
|
||||
data = []
|
||||
for row in rows:
|
||||
data.append(
|
||||
{
|
||||
"id": row[0],
|
||||
"source": row[1],
|
||||
"source_id": row[2],
|
||||
"data_type": row[3],
|
||||
"name": row[4],
|
||||
"title": row[5],
|
||||
"description": row[6],
|
||||
"country": row[7],
|
||||
"city": row[8],
|
||||
"latitude": row[9],
|
||||
"longitude": row[10],
|
||||
"value": row[11],
|
||||
"unit": row[12],
|
||||
"metadata": row[13],
|
||||
"collected_at": row[14].isoformat() if row[14] else None,
|
||||
"reference_date": row[15].isoformat() if row[15] else None,
|
||||
"is_valid": row[16],
|
||||
}
|
||||
)
|
||||
|
||||
json_str = json.dumps({"data": data, "total": len(data)}, ensure_ascii=False, indent=2)
|
||||
|
||||
return StreamingResponse(
|
||||
io.StringIO(json_str),
|
||||
media_type="application/json",
|
||||
headers={
|
||||
"Content-Disposition": f"attachment; filename=collected_data_{source or 'all'}.json"
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.get("/export/csv")
|
||||
async def export_csv(
|
||||
source: Optional[str] = Query(None, description="数据源过滤"),
|
||||
data_type: Optional[str] = Query(None, description="数据类型过滤"),
|
||||
country: Optional[str] = Query(None, description="国家过滤"),
|
||||
search: Optional[str] = Query(None, description="搜索名称"),
|
||||
limit: int = Query(10000, ge=1, le=50000, description="最大导出数量"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""导出数据为 CSV 格式"""
|
||||
|
||||
where_sql, params = build_where_clause(source, data_type, country, search)
|
||||
params["limit"] = limit
|
||||
|
||||
query = text(f"""
|
||||
SELECT id, source, source_id, data_type, name, title, description,
|
||||
country, city, latitude, longitude, value, unit,
|
||||
metadata, collected_at, reference_date, is_valid
|
||||
FROM collected_data
|
||||
WHERE {where_sql}
|
||||
ORDER BY collected_at DESC
|
||||
LIMIT :limit
|
||||
""")
|
||||
|
||||
result = await db.execute(query, params)
|
||||
rows = result.fetchall()
|
||||
|
||||
output = io.StringIO()
|
||||
writer = csv.writer(output)
|
||||
|
||||
# Write header
|
||||
writer.writerow(
|
||||
[
|
||||
"ID",
|
||||
"Source",
|
||||
"Source ID",
|
||||
"Type",
|
||||
"Name",
|
||||
"Title",
|
||||
"Description",
|
||||
"Country",
|
||||
"City",
|
||||
"Latitude",
|
||||
"Longitude",
|
||||
"Value",
|
||||
"Unit",
|
||||
"Metadata",
|
||||
"Collected At",
|
||||
"Reference Date",
|
||||
"Is Valid",
|
||||
]
|
||||
)
|
||||
|
||||
# Write data
|
||||
for row in rows:
|
||||
writer.writerow(
|
||||
[
|
||||
row[0],
|
||||
row[1],
|
||||
row[2],
|
||||
row[3],
|
||||
row[4],
|
||||
row[5],
|
||||
row[6],
|
||||
row[7],
|
||||
row[8],
|
||||
row[9],
|
||||
row[10],
|
||||
row[11],
|
||||
row[12],
|
||||
json.dumps(row[13]) if row[13] else "",
|
||||
row[14].isoformat() if row[14] else "",
|
||||
row[15].isoformat() if row[15] else "",
|
||||
row[16],
|
||||
]
|
||||
)
|
||||
|
||||
return StreamingResponse(
|
||||
io.StringIO(output.getvalue()),
|
||||
media_type="text/csv",
|
||||
headers={
|
||||
"Content-Disposition": f"attachment; filename=collected_data_{source or 'all'}.csv"
|
||||
},
|
||||
)
|
||||
239
backend/app/api/v1/dashboard.py
Normal file
239
backend/app/api/v1/dashboard.py
Normal file
@@ -0,0 +1,239 @@
|
||||
"""Dashboard API with caching and optimizations"""
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy import select, func, text
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db.session import get_db
|
||||
from app.models.user import User
|
||||
from app.models.datasource import DataSource
|
||||
from app.models.datasource_config import DataSourceConfig
|
||||
from app.models.alert import Alert, AlertSeverity
|
||||
from app.models.task import CollectionTask
|
||||
from app.core.security import get_current_user
|
||||
from app.core.cache import cache
|
||||
|
||||
# Built-in collectors info (mirrored from datasources.py)
|
||||
COLLECTOR_INFO = {
|
||||
"top500": {
|
||||
"id": 1,
|
||||
"name": "TOP500 Supercomputers",
|
||||
"module": "L1",
|
||||
"priority": "P0",
|
||||
"frequency_hours": 4,
|
||||
},
|
||||
"epoch_ai_gpu": {
|
||||
"id": 2,
|
||||
"name": "Epoch AI GPU Clusters",
|
||||
"module": "L1",
|
||||
"priority": "P0",
|
||||
"frequency_hours": 6,
|
||||
},
|
||||
"huggingface_models": {
|
||||
"id": 3,
|
||||
"name": "HuggingFace Models",
|
||||
"module": "L2",
|
||||
"priority": "P1",
|
||||
"frequency_hours": 12,
|
||||
},
|
||||
"huggingface_datasets": {
|
||||
"id": 4,
|
||||
"name": "HuggingFace Datasets",
|
||||
"module": "L2",
|
||||
"priority": "P1",
|
||||
"frequency_hours": 12,
|
||||
},
|
||||
"huggingface_spaces": {
|
||||
"id": 5,
|
||||
"name": "HuggingFace Spaces",
|
||||
"module": "L2",
|
||||
"priority": "P2",
|
||||
"frequency_hours": 24,
|
||||
},
|
||||
"peeringdb_ixp": {
|
||||
"id": 6,
|
||||
"name": "PeeringDB IXP",
|
||||
"module": "L2",
|
||||
"priority": "P1",
|
||||
"frequency_hours": 24,
|
||||
},
|
||||
"peeringdb_network": {
|
||||
"id": 7,
|
||||
"name": "PeeringDB Networks",
|
||||
"module": "L2",
|
||||
"priority": "P2",
|
||||
"frequency_hours": 48,
|
||||
},
|
||||
"peeringdb_facility": {
|
||||
"id": 8,
|
||||
"name": "PeeringDB Facilities",
|
||||
"module": "L2",
|
||||
"priority": "P2",
|
||||
"frequency_hours": 48,
|
||||
},
|
||||
"telegeography_cables": {
|
||||
"id": 9,
|
||||
"name": "Submarine Cables",
|
||||
"module": "L2",
|
||||
"priority": "P1",
|
||||
"frequency_hours": 168,
|
||||
},
|
||||
"telegeography_landing": {
|
||||
"id": 10,
|
||||
"name": "Cable Landing Points",
|
||||
"module": "L2",
|
||||
"priority": "P2",
|
||||
"frequency_hours": 168,
|
||||
},
|
||||
"telegeography_systems": {
|
||||
"id": 11,
|
||||
"name": "Cable Systems",
|
||||
"module": "L2",
|
||||
"priority": "P2",
|
||||
"frequency_hours": 168,
|
||||
},
|
||||
}
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/stats")
|
||||
async def get_stats(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Get dashboard statistics with caching"""
|
||||
cache_key = "dashboard:stats"
|
||||
|
||||
cached_result = cache.get(cache_key)
|
||||
if cached_result:
|
||||
return cached_result
|
||||
|
||||
today_start = datetime.utcnow().replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
|
||||
# Count built-in collectors
|
||||
built_in_count = len(COLLECTOR_INFO)
|
||||
built_in_active = built_in_count # Built-in are always "active" for counting purposes
|
||||
|
||||
# Count custom configs from database
|
||||
result = await db.execute(select(func.count(DataSourceConfig.id)))
|
||||
custom_count = result.scalar() or 0
|
||||
|
||||
result = await db.execute(
|
||||
select(func.count(DataSourceConfig.id)).where(DataSourceConfig.is_active == True)
|
||||
)
|
||||
custom_active = result.scalar() or 0
|
||||
|
||||
# Total datasources
|
||||
total_datasources = built_in_count + custom_count
|
||||
active_datasources = built_in_active + custom_active
|
||||
|
||||
# Tasks today (from database)
|
||||
result = await db.execute(
|
||||
select(func.count(CollectionTask.id)).where(CollectionTask.started_at >= today_start)
|
||||
)
|
||||
tasks_today = result.scalar() or 0
|
||||
|
||||
result = await db.execute(
|
||||
select(func.count(CollectionTask.id)).where(
|
||||
CollectionTask.status == "success",
|
||||
CollectionTask.started_at >= today_start,
|
||||
)
|
||||
)
|
||||
success_tasks = result.scalar() or 0
|
||||
success_rate = (success_tasks / tasks_today * 100) if tasks_today > 0 else 0
|
||||
|
||||
# Alerts
|
||||
result = await db.execute(
|
||||
select(func.count(Alert.id)).where(
|
||||
Alert.severity == AlertSeverity.CRITICAL,
|
||||
Alert.status == "active",
|
||||
)
|
||||
)
|
||||
critical_alerts = result.scalar() or 0
|
||||
|
||||
result = await db.execute(
|
||||
select(func.count(Alert.id)).where(
|
||||
Alert.severity == AlertSeverity.WARNING,
|
||||
Alert.status == "active",
|
||||
)
|
||||
)
|
||||
warning_alerts = result.scalar() or 0
|
||||
|
||||
result = await db.execute(
|
||||
select(func.count(Alert.id)).where(
|
||||
Alert.severity == AlertSeverity.INFO,
|
||||
Alert.status == "active",
|
||||
)
|
||||
)
|
||||
info_alerts = result.scalar() or 0
|
||||
|
||||
response = {
|
||||
"total_datasources": total_datasources,
|
||||
"active_datasources": active_datasources,
|
||||
"tasks_today": tasks_today,
|
||||
"success_rate": round(success_rate, 1),
|
||||
"last_updated": datetime.utcnow().isoformat(),
|
||||
"alerts": {
|
||||
"critical": critical_alerts,
|
||||
"warning": warning_alerts,
|
||||
"info": info_alerts,
|
||||
},
|
||||
}
|
||||
|
||||
cache.set(cache_key, response, expire_seconds=60)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
@router.get("/summary")
async def get_summary(
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Get dashboard summary by module with caching"""
    cache_key = "dashboard:summary"

    cached = cache.get(cache_key)
    if cached:
        return cached

    # Tally the built-in collectors per module layer (L1/L2/...).
    per_module: dict = {}
    for name, info in COLLECTOR_INFO.items():
        bucket = per_module.setdefault(info["module"], {"datasources": 0, "sources": []})
        bucket["datasources"] += 1
        bucket["sources"].append(info["name"])

    # Tally active user-defined configs grouped by source_type; all of
    # them are attributed to the L3 layer by convention.
    result = await db.execute(
        select(DataSourceConfig.source_type, func.count(DataSourceConfig.id).label("count"))
        .where(DataSourceConfig.is_active == True)
        .group_by(DataSourceConfig.source_type)
    )
    for row in result.fetchall():
        bucket = per_module.setdefault("L3", {"datasources": 0, "sources": []})
        bucket["datasources"] += row.count
        bucket["sources"].append(f"自定义 ({row.source_type})")

    summary = {
        module: {
            "datasources": data["datasources"],
            "total_records": 0,  # built-in collectors don't track this in dashboard stats
            "last_updated": datetime.utcnow().isoformat(),
        }
        for module, data in per_module.items()
    }

    response = {"modules": summary, "last_updated": datetime.utcnow().isoformat()}

    cache.set(cache_key, response, expire_seconds=300)

    return response
|
||||
309
backend/app/api/v1/datasource_config.py
Normal file
309
backend/app/api/v1/datasource_config.py
Normal file
@@ -0,0 +1,309 @@
|
||||
"""DataSourceConfig API for user-defined data sources"""
|
||||
|
||||
from typing import Optional
|
||||
from datetime import datetime
|
||||
import base64
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy import select, func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from pydantic import BaseModel, Field
|
||||
import httpx
|
||||
|
||||
from app.db.session import get_db
|
||||
from app.models.user import User
|
||||
from app.models.datasource_config import DataSourceConfig
|
||||
from app.core.security import get_current_user
|
||||
from app.core.cache import cache
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class DataSourceConfigCreate(BaseModel):
    """Request payload for registering a new user-defined data source."""

    name: str = Field(..., min_length=1, max_length=100)
    description: Optional[str] = None
    source_type: str = Field(..., description="http, api, database")
    endpoint: str = Field(..., max_length=500)
    auth_type: str = Field(default="none", description="none, bearer, api_key, basic")
    # Credential material (token / api_key / username+password) keyed by auth_type.
    auth_config: dict = Field(default={})
    headers: dict = Field(default={})
    config: dict = Field(default={"timeout": 30, "retry": 3})
|
||||
|
||||
|
||||
class DataSourceConfigUpdate(BaseModel):
    """Partial-update payload; only fields explicitly set are applied."""

    name: Optional[str] = Field(None, min_length=1, max_length=100)
    description: Optional[str] = None
    source_type: Optional[str] = None
    endpoint: Optional[str] = Field(None, max_length=500)
    auth_type: Optional[str] = None
    auth_config: Optional[dict] = None
    headers: Optional[dict] = None
    config: Optional[dict] = None
    is_active: Optional[bool] = None
|
||||
|
||||
|
||||
class DataSourceConfigResponse(BaseModel):
    """Serialized view of a stored configuration (auth secrets excluded)."""

    id: int
    name: str
    description: Optional[str]
    source_type: str
    endpoint: str
    auth_type: str
    headers: dict
    config: dict
    is_active: bool
    created_at: datetime
    updated_at: datetime

    class Config:
        # Allow construction directly from ORM objects.
        from_attributes = True
|
||||
|
||||
|
||||
async def test_endpoint(
    endpoint: str,
    auth_type: str,
    auth_config: dict,
    headers: dict,
    config: dict,
) -> dict:
    """Probe an endpoint with the given auth settings and return a result dict.

    Issues a single GET with the configured auth headers and timeout.
    Raises httpx.HTTPStatusError on non-2xx responses (callers catch it).

    Fix: the original built the preview with ``response.json()[:3]``, which
    raises ``TypeError: unhashable type: 'slice'`` whenever the JSON body is
    an object (dict) rather than an array — the common case for APIs.
    """
    timeout = config.get("timeout", 30)
    test_headers = headers.copy()

    # Inject authentication material according to the configured scheme.
    if auth_type == "bearer" and auth_config.get("token"):
        test_headers["Authorization"] = f"Bearer {auth_config['token']}"
    elif auth_type == "api_key" and auth_config.get("api_key"):
        key_name = auth_config.get("key_name", "X-API-Key")
        test_headers[key_name] = auth_config["api_key"]
    elif auth_type == "basic":
        username = auth_config.get("username", "")
        password = auth_config.get("password", "")
        credentials = f"{username}:{password}"
        encoded = base64.b64encode(credentials.encode()).decode()
        test_headers["Authorization"] = f"Basic {encoded}"

    async with httpx.AsyncClient(timeout=timeout) as client:
        response = await client.get(endpoint, headers=test_headers)
        response.raise_for_status()

        # Build a short, type-safe preview of the response body.
        if response.headers.get("content-type", "").startswith("application/json"):
            body = response.json()
            # Only lists are sliceable; for objects/scalars truncate the repr.
            preview = str(body[:3]) if isinstance(body, list) else str(body)[:200]
        else:
            preview = response.text[:200]

        return {
            "status_code": response.status_code,
            "success": True,
            "response_time_ms": response.elapsed.total_seconds() * 1000,
            "data_preview": preview,
        }
|
||||
|
||||
|
||||
@router.get("/configs")
async def list_configs(
    active_only: bool = False,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """List all user-defined data source configurations"""

    def serialize(c) -> dict:
        # auth_config is deliberately omitted — never echo secrets.
        return {
            "id": c.id,
            "name": c.name,
            "description": c.description,
            "source_type": c.source_type,
            "endpoint": c.endpoint,
            "auth_type": c.auth_type,
            "headers": c.headers,
            "config": c.config,
            "is_active": c.is_active,
            "created_at": c.created_at.isoformat() if c.created_at else None,
            "updated_at": c.updated_at.isoformat() if c.updated_at else None,
        }

    stmt = select(DataSourceConfig)
    if active_only:
        stmt = stmt.where(DataSourceConfig.is_active == True)
    stmt = stmt.order_by(DataSourceConfig.created_at.desc())

    rows = (await db.execute(stmt)).scalars().all()

    return {"total": len(rows), "data": [serialize(c) for c in rows]}
|
||||
|
||||
|
||||
@router.get("/configs/{config_id}")
async def get_config(
    config_id: int,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Get a single data source configuration"""
    row = (
        await db.execute(select(DataSourceConfig).where(DataSourceConfig.id == config_id))
    ).scalar_one_or_none()

    if row is None:
        raise HTTPException(status_code=404, detail="Configuration not found")

    return {
        "id": row.id,
        "name": row.name,
        "description": row.description,
        "source_type": row.source_type,
        "endpoint": row.endpoint,
        "auth_type": row.auth_type,
        "auth_config": {},  # Don't return sensitive data
        "headers": row.headers,
        "config": row.config,
        "is_active": row.is_active,
        "created_at": row.created_at.isoformat() if row.created_at else None,
        "updated_at": row.updated_at.isoformat() if row.updated_at else None,
    }
|
||||
|
||||
|
||||
@router.post("/configs")
async def create_config(
    config_data: DataSourceConfigCreate,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Create a new data source configuration"""
    record = DataSourceConfig(
        name=config_data.name,
        description=config_data.description,
        source_type=config_data.source_type,
        endpoint=config_data.endpoint,
        auth_type=config_data.auth_type,
        auth_config=config_data.auth_config,
        headers=config_data.headers,
        config=config_data.config,
    )
    db.add(record)
    await db.commit()
    await db.refresh(record)

    # Any cached listing is now stale.
    cache.delete_pattern("datasource_configs:*")

    return {
        "id": record.id,
        "name": record.name,
        "message": "Configuration created successfully",
    }
|
||||
|
||||
|
||||
@router.put("/configs/{config_id}")
async def update_config(
    config_id: int,
    config_data: DataSourceConfigUpdate,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Update a data source configuration"""
    record = (
        await db.execute(select(DataSourceConfig).where(DataSourceConfig.id == config_id))
    ).scalar_one_or_none()

    if record is None:
        raise HTTPException(status_code=404, detail="Configuration not found")

    # Apply only the fields the client actually sent.
    for field, value in config_data.model_dump(exclude_unset=True).items():
        setattr(record, field, value)

    await db.commit()
    await db.refresh(record)

    cache.delete_pattern("datasource_configs:*")

    return {
        "id": record.id,
        "name": record.name,
        "message": "Configuration updated successfully",
    }
|
||||
|
||||
|
||||
@router.delete("/configs/{config_id}")
async def delete_config(
    config_id: int,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Delete a data source configuration"""
    record = (
        await db.execute(select(DataSourceConfig).where(DataSourceConfig.id == config_id))
    ).scalar_one_or_none()

    if record is None:
        raise HTTPException(status_code=404, detail="Configuration not found")

    await db.delete(record)
    await db.commit()

    cache.delete_pattern("datasource_configs:*")

    return {"message": "Configuration deleted successfully"}
|
||||
|
||||
|
||||
@router.post("/configs/{config_id}/test")
async def test_config(
    config_id: int,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Test a data source configuration"""
    record = (
        await db.execute(select(DataSourceConfig).where(DataSourceConfig.id == config_id))
    ).scalar_one_or_none()

    if record is None:
        raise HTTPException(status_code=404, detail="Configuration not found")

    try:
        # Probe the stored endpoint; missing JSON columns fall back to {}.
        return await test_endpoint(
            endpoint=record.endpoint,
            auth_type=record.auth_type,
            auth_config=record.auth_config or {},
            headers=record.headers or {},
            config=record.config or {},
        )
    except httpx.HTTPStatusError as e:
        return {
            "success": False,
            "error": f"HTTP Error: {e.response.status_code}",
            "message": str(e),
        }
    except Exception as e:
        return {
            "success": False,
            "error": "Connection failed",
            "message": str(e),
        }
|
||||
|
||||
|
||||
@router.post("/configs/test")
async def test_new_config(
    config_data: DataSourceConfigCreate,
    current_user: User = Depends(get_current_user),
):
    """Test a new data source configuration without saving"""
    try:
        # Dry-run probe straight from the request payload — nothing persisted.
        return await test_endpoint(
            endpoint=config_data.endpoint,
            auth_type=config_data.auth_type,
            auth_config=config_data.auth_config or {},
            headers=config_data.headers or {},
            config=config_data.config or {},
        )
    except httpx.HTTPStatusError as e:
        return {
            "success": False,
            "error": f"HTTP Error: {e.response.status_code}",
            "message": str(e),
        }
    except Exception as e:
        return {
            "success": False,
            "error": "Connection failed",
            "message": str(e),
        }
|
||||
258
backend/app/api/v1/datasources.py
Normal file
258
backend/app/api/v1/datasources.py
Normal file
@@ -0,0 +1,258 @@
|
||||
from typing import List, Optional
|
||||
from datetime import datetime
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy import select, func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db.session import get_db
|
||||
from app.models.user import User
|
||||
from app.models.datasource import DataSource
|
||||
from app.core.security import get_current_user
|
||||
from app.services.collectors.registry import collector_registry
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
# Static registry of the built-in collectors exposed by this API.
# Keys are collector names understood by collector_registry; values carry
# the metadata returned to clients: stable numeric id, display name,
# module layer (L1/L2), priority tier, and collection frequency in hours.
COLLECTOR_INFO = {
    "top500": {
        "id": 1,
        "name": "TOP500 Supercomputers",
        "module": "L1",
        "priority": "P0",
        "frequency_hours": 4,
    },
    "epoch_ai_gpu": {
        "id": 2,
        "name": "Epoch AI GPU Clusters",
        "module": "L1",
        "priority": "P0",
        "frequency_hours": 6,
    },
    "huggingface_models": {
        "id": 3,
        "name": "HuggingFace Models",
        "module": "L2",
        "priority": "P1",
        "frequency_hours": 12,
    },
    "huggingface_datasets": {
        "id": 4,
        "name": "HuggingFace Datasets",
        "module": "L2",
        "priority": "P1",
        "frequency_hours": 12,
    },
    "huggingface_spaces": {
        "id": 5,
        "name": "HuggingFace Spaces",
        "module": "L2",
        "priority": "P2",
        "frequency_hours": 24,
    },
    "peeringdb_ixp": {
        "id": 6,
        "name": "PeeringDB IXP",
        "module": "L2",
        "priority": "P1",
        "frequency_hours": 24,
    },
    "peeringdb_network": {
        "id": 7,
        "name": "PeeringDB Networks",
        "module": "L2",
        "priority": "P2",
        "frequency_hours": 48,
    },
    "peeringdb_facility": {
        "id": 8,
        "name": "PeeringDB Facilities",
        "module": "L2",
        "priority": "P2",
        "frequency_hours": 48,
    },
    "telegeography_cables": {
        "id": 9,
        "name": "Submarine Cables",
        "module": "L2",
        "priority": "P1",
        "frequency_hours": 168,
    },
    "telegeography_landing": {
        "id": 10,
        "name": "Cable Landing Points",
        "module": "L2",
        "priority": "P2",
        "frequency_hours": 168,
    },
    "telegeography_systems": {
        "id": 11,
        "name": "Cable Systems",
        "module": "L2",
        "priority": "P2",
        "frequency_hours": 168,
    },
}

# Bidirectional lookups between the numeric ids above and collector names.
ID_TO_COLLECTOR = {info["id"]: name for name, info in COLLECTOR_INFO.items()}
COLLECTOR_TO_ID = {name: info["id"] for name, info in COLLECTOR_INFO.items()}
|
||||
|
||||
|
||||
def get_collector_name(source_id: str) -> Optional[str]:
    """Resolve a numeric id or collector key to a registered collector name.

    Accepts either the stable numeric id (as a string) or the collector's
    own key; returns None when neither matches.
    """
    try:
        candidate = ID_TO_COLLECTOR.get(int(source_id))
    except ValueError:
        # Not an integer — fall through to name lookup.
        candidate = None
    if candidate is not None:
        return candidate
    return source_id if source_id in COLLECTOR_INFO else None
|
||||
|
||||
|
||||
@router.get("")
async def list_datasources(
    module: Optional[str] = None,
    is_active: Optional[bool] = None,
    priority: Optional[str] = None,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """List the built-in collectors as data sources, with optional filters.

    Query params: ``module`` (L1/L2), ``priority`` (P0/P1/P2) and
    ``is_active`` all narrow the returned list.

    Fixes vs. the original: it executed a DataSource DB query whose results
    were discarded (pure wasted round-trip), and it accepted ``is_active``
    without ever applying it to the response.
    """
    collector_list = [
        {
            "id": info["id"],
            "name": info["name"],
            "module": info["module"],
            "priority": info["priority"],
            "frequency": f"{info['frequency_hours']}h",
            # Live enable/disable state comes from the registry, not the DB.
            "is_active": collector_registry.is_active(name),
            "collector_class": name,
        }
        for name, info in COLLECTOR_INFO.items()
    ]

    if module:
        collector_list = [c for c in collector_list if c["module"] == module]
    if priority:
        collector_list = [c for c in collector_list if c["priority"] == priority]
    if is_active is not None:
        collector_list = [c for c in collector_list if c["is_active"] == is_active]

    return {
        "total": len(collector_list),
        "data": collector_list,
    }
|
||||
|
||||
|
||||
@router.get("/{source_id}")
async def get_datasource(
    source_id: str,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return metadata for one built-in collector, by numeric id or key."""
    collector_name = get_collector_name(source_id)
    if collector_name is None:
        raise HTTPException(status_code=404, detail="Data source not found")

    meta = COLLECTOR_INFO[collector_name]
    return {
        "id": meta["id"],
        "name": meta["name"],
        "module": meta["module"],
        "priority": meta["priority"],
        "frequency": f"{meta['frequency_hours']}h",
        "collector_class": collector_name,
        "is_active": collector_registry.is_active(collector_name),
    }
|
||||
|
||||
|
||||
@router.post("/{source_id}/enable")
async def enable_datasource(
    source_id: str,
    current_user: User = Depends(get_current_user),
):
    """Mark a built-in collector active in the registry."""
    collector_name = get_collector_name(source_id)
    if collector_name is None:
        raise HTTPException(status_code=404, detail="Data source not found")
    collector_registry.set_active(collector_name, True)
    return {"status": "enabled", "source_id": source_id}
|
||||
|
||||
|
||||
@router.post("/{source_id}/disable")
async def disable_datasource(
    source_id: str,
    current_user: User = Depends(get_current_user),
):
    """Mark a built-in collector inactive in the registry."""
    collector_name = get_collector_name(source_id)
    if collector_name is None:
        raise HTTPException(status_code=404, detail="Data source not found")
    collector_registry.set_active(collector_name, False)
    return {"status": "disabled", "source_id": source_id}
|
||||
|
||||
|
||||
@router.get("/{source_id}/stats")
async def get_datasource_stats(
    source_id: str,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return the record count stored for one built-in collector."""
    collector_name = get_collector_name(source_id)
    if collector_name is None:
        raise HTTPException(status_code=404, detail="Data source not found")

    meta = COLLECTOR_INFO[collector_name]
    # Records are attributed to a collector via the DataSource.source column.
    total = (
        await db.execute(
            select(func.count(DataSource.id)).where(DataSource.source == meta["name"])
        )
    ).scalar() or 0

    return {
        "source_id": source_id,
        "collector_name": collector_name,
        "name": meta["name"],
        "total_records": total,
        "last_updated": datetime.utcnow().isoformat(),
    }
|
||||
|
||||
|
||||
@router.post("/{source_id}/trigger")
async def trigger_datasource(
    source_id: str,
    current_user: User = Depends(get_current_user),
):
    """Ask the scheduler to run one collector immediately."""
    collector_name = get_collector_name(source_id)
    if collector_name is None:
        raise HTTPException(status_code=404, detail="Data source not found")

    # Imported lazily to avoid a scheduler import cycle at module load.
    from app.services.scheduler import run_collector_now

    if not collector_registry.is_active(collector_name):
        raise HTTPException(status_code=400, detail="Data source is disabled")

    if not run_collector_now(collector_name):
        raise HTTPException(
            status_code=500,
            detail=f"Failed to trigger collector '{collector_name}'",
        )

    return {
        "status": "triggered",
        "source_id": source_id,
        "collector_name": collector_name,
        "message": f"Collector '{collector_name}' has been triggered",
    }
|
||||
110
backend/app/api/v1/settings.py
Normal file
110
backend/app/api/v1/settings.py
Normal file
@@ -0,0 +1,110 @@
|
||||
from typing import Optional
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from pydantic import BaseModel, EmailStr
|
||||
|
||||
from app.models.user import User
|
||||
from app.core.security import get_current_user
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
# Factory defaults for the three settings sections. Settings live only in
# process memory (no persistence) — a restart resets them, and multiple
# workers will each hold their own copy.
default_settings = {
    "system": {
        "system_name": "智能星球",
        "refresh_interval": 60,
        "auto_refresh": True,
        "data_retention_days": 30,
        "max_concurrent_tasks": 5,
    },
    "notifications": {
        "email_enabled": False,
        "email_address": "",
        "critical_alerts": True,
        "warning_alerts": True,
        "daily_summary": False,
    },
    "security": {
        "session_timeout": 60,
        "max_login_attempts": 5,
        "password_policy": "medium",
    },
}

# Mutable working copies; the PUT endpoints rebind these module globals.
system_settings = default_settings["system"].copy()
notification_settings = default_settings["notifications"].copy()
security_settings = default_settings["security"].copy()
|
||||
|
||||
|
||||
class SystemSettingsUpdate(BaseModel):
    """Full replacement payload for the system settings section."""

    system_name: str = "智能星球"
    refresh_interval: int = 60
    auto_refresh: bool = True
    data_retention_days: int = 30
    max_concurrent_tasks: int = 5
|
||||
|
||||
|
||||
class NotificationSettingsUpdate(BaseModel):
    """Full replacement payload for the notification settings section."""

    email_enabled: bool = False
    email_address: Optional[EmailStr] = None
    critical_alerts: bool = True
    warning_alerts: bool = True
    daily_summary: bool = False
|
||||
|
||||
|
||||
class SecuritySettingsUpdate(BaseModel):
    """Full replacement payload for the security settings section."""

    session_timeout: int = 60
    max_login_attempts: int = 5
    password_policy: str = "medium"
|
||||
|
||||
|
||||
@router.get("/system")
async def get_system_settings(current_user: User = Depends(get_current_user)):
    """Return the in-memory system settings section."""
    return {"system": system_settings}
|
||||
|
||||
|
||||
@router.put("/system")
async def update_system_settings(
    settings: SystemSettingsUpdate,
    current_user: User = Depends(get_current_user),
):
    """Replace the in-memory system settings with the submitted values."""
    global system_settings
    # Full replace: any field omitted by the client falls back to the
    # model's defaults, not the previously stored value.
    system_settings = settings.model_dump()
    return {"status": "updated", "system": system_settings}
|
||||
|
||||
|
||||
@router.get("/notifications")
async def get_notification_settings(current_user: User = Depends(get_current_user)):
    """Return the in-memory notification settings section."""
    return {"notifications": notification_settings}
|
||||
|
||||
|
||||
@router.put("/notifications")
async def update_notification_settings(
    settings: NotificationSettingsUpdate,
    current_user: User = Depends(get_current_user),
):
    """Replace the in-memory notification settings with the submitted values."""
    global notification_settings
    # Full replace — omitted fields revert to model defaults.
    notification_settings = settings.model_dump()
    return {"status": "updated", "notifications": notification_settings}
|
||||
|
||||
|
||||
@router.get("/security")
async def get_security_settings(current_user: User = Depends(get_current_user)):
    """Return the in-memory security settings section."""
    return {"security": security_settings}
|
||||
|
||||
|
||||
@router.put("/security")
async def update_security_settings(
    settings: SecuritySettingsUpdate,
    current_user: User = Depends(get_current_user),
):
    """Replace the in-memory security settings with the submitted values."""
    global security_settings
    # Full replace — omitted fields revert to model defaults.
    security_settings = settings.model_dump()
    return {"status": "updated", "security": security_settings}
|
||||
|
||||
|
||||
@router.get("")
async def get_all_settings(current_user: User = Depends(get_current_user)):
    """Return every settings section in one payload."""
    sections = {
        "system": system_settings,
        "notifications": notification_settings,
        "security": security_settings,
    }
    return sections
|
||||
157
backend/app/api/v1/tasks.py
Normal file
157
backend/app/api/v1/tasks.py
Normal file
@@ -0,0 +1,157 @@
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import text
|
||||
|
||||
from app.db.session import get_db
|
||||
from app.models.user import User
|
||||
from app.core.security import get_current_user
|
||||
from app.services.collectors.registry import collector_registry
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("")
async def list_tasks(
    datasource_id: int = None,
    status: str = None,
    page: int = 1,
    page_size: int = 20,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Paginated listing of collection tasks, optionally filtered.

    Fixes vs. the original:
    - when ``datasource_id`` was given, the count query appended a second
      ``WHERE`` clause to a statement already containing ``WHERE 1=1``,
      producing invalid SQL;
    - LIMIT/OFFSET were interpolated via f-string instead of bound params.
    """
    offset = (page - 1) * page_size

    base_query = """
        SELECT ct.id, ct.datasource_id, ds.name as datasource_name, ct.status,
               ct.started_at, ct.completed_at, ct.records_processed, ct.error_message
        FROM collection_tasks ct
        JOIN data_sources ds ON ct.datasource_id = ds.id
        WHERE 1=1
    """
    count_query = "SELECT COUNT(*) FROM collection_tasks ct WHERE 1=1"
    params = {}

    # Build one filter fragment shared by both statements so they can
    # never disagree (this is what broke the original count query).
    filter_sql = ""
    if datasource_id:
        filter_sql += " AND ct.datasource_id = :datasource_id"
        params["datasource_id"] = datasource_id
    if status:
        filter_sql += " AND ct.status = :status"
        params["status"] = status

    base_query += filter_sql + " ORDER BY ct.created_at DESC LIMIT :limit OFFSET :offset"
    count_query += filter_sql

    result = await db.execute(
        text(base_query), {**params, "limit": page_size, "offset": offset}
    )
    tasks = result.fetchall()

    count_result = await db.execute(text(count_query), params)
    total = count_result.scalar()

    return {
        "total": total or 0,
        "page": page,
        "page_size": page_size,
        "data": [
            {
                "id": t[0],
                "datasource_id": t[1],
                "datasource_name": t[2],
                "status": t[3],
                "started_at": t[4].isoformat() if t[4] else None,
                "completed_at": t[5].isoformat() if t[5] else None,
                "records_processed": t[6],
                "error_message": t[7],
            }
            for t in tasks
        ],
    }
|
||||
|
||||
|
||||
@router.get("/{task_id}")
async def get_task(
    task_id: int,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Fetch a single collection task by id, joined with its datasource name."""
    row = (
        await db.execute(
            text("""
            SELECT ct.id, ct.datasource_id, ds.name as datasource_name, ct.status,
                   ct.started_at, ct.completed_at, ct.records_processed, ct.error_message
            FROM collection_tasks ct
            JOIN data_sources ds ON ct.datasource_id = ds.id
            WHERE ct.id = :id
            """),
            {"id": task_id},
        )
    ).fetchone()

    if row is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Task not found",
        )

    return {
        "id": row[0],
        "datasource_id": row[1],
        "datasource_name": row[2],
        "status": row[3],
        "started_at": row[4].isoformat() if row[4] else None,
        "completed_at": row[5].isoformat() if row[5] else None,
        "records_processed": row[6],
        "error_message": row[7],
    }
|
||||
|
||||
|
||||
@router.post("/datasources/{source_id}/trigger")
async def trigger_collection(
    source_id: int,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Run a datasource's collector synchronously and record the outcome.

    Fix vs. the original: the INSERT into collection_tasks was never
    committed — every other write path in this module commits explicitly,
    so the task record could be silently rolled back when the session
    closed.
    """
    result = await db.execute(
        text("SELECT id, name, collector_class FROM data_sources WHERE id = :id"),
        {"id": source_id},
    )
    datasource = result.fetchone()

    if not datasource:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Data source not found",
        )

    # Registry keys are the lowercase class name minus the "Collector" suffix.
    collector_class_name = datasource[2]
    collector_name = collector_class_name.lower().replace("collector", "")

    collector = collector_registry.get(collector_name)
    if not collector:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Collector {collector_name} not found",
        )

    run_result = await collector.run(db)

    await db.execute(
        text("""
        INSERT INTO collection_tasks (datasource_id, status, records_processed, error_message, started_at, completed_at, created_at)
        VALUES (:datasource_id, :status, :records_processed, :error_message, :started_at, :completed_at, NOW())
        """),
        {
            "datasource_id": source_id,
            "status": run_result.get("status", "unknown"),
            "records_processed": run_result.get("records_processed", 0),
            "error_message": run_result.get("error"),
            "started_at": datetime.utcnow(),
            "completed_at": datetime.utcnow(),
        },
    )
    # Persist the task record (the original relied on an implicit commit
    # that never happened).
    await db.commit()

    return {
        "message": "Collection task executed",
        "result": run_result,
    }
|
||||
263
backend/app/api/v1/users.py
Normal file
263
backend/app/api/v1/users.py
Normal file
@@ -0,0 +1,263 @@
|
||||
from typing import List
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import text
|
||||
|
||||
from app.core.security import get_current_user, get_password_hash
|
||||
from app.db.session import get_db
|
||||
from app.models.user import User
|
||||
from app.schemas.user import UserCreate, UserResponse, UserUpdate
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def check_permission(current_user: User, required_roles: List[str]) -> bool:
    """Return True if the user's role is one of required_roles.

    Handles both enum-valued roles (compared by .value) and raw strings.
    """
    role = getattr(current_user.role, "value", current_user.role)
    return role in required_roles
|
||||
|
||||
|
||||
@router.get("", response_model=dict)
async def list_users(
    page: int = 1,
    page_size: int = 20,
    role: str = None,
    is_active: bool = None,
    search: str = None,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Admin-only paginated user listing with role/active/search filters.

    Fixes vs. the original: LIMIT/OFFSET were f-string-interpolated into
    the SQL instead of bound as parameters, and ``page < 1`` produced a
    negative OFFSET (a database error); the page is now clamped to 1.
    """
    if not check_permission(current_user, ["super_admin", "admin"]):
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Insufficient permissions",
        )

    # Build WHERE clause from fixed fragments; values go through bind params.
    where_clauses = []
    params = {}
    if role:
        where_clauses.append("role = :role")
        params["role"] = role
    if is_active is not None:
        where_clauses.append("is_active = :is_active")
        params["is_active"] = is_active
    if search:
        where_clauses.append("(username ILIKE :search OR email ILIKE :search)")
        params["search"] = f"%{search}%"

    where_sql = " AND ".join(where_clauses) if where_clauses else "1=1"

    page = max(page, 1)  # guard against negative OFFSET
    offset = (page - 1) * page_size
    query = text(
        f"SELECT id, username, email, role, is_active, last_login_at, created_at "
        f"FROM users WHERE {where_sql} ORDER BY created_at DESC LIMIT :limit OFFSET :offset"
    )
    count_query = text(f"SELECT COUNT(*) FROM users WHERE {where_sql}")

    result = await db.execute(query, {**params, "limit": page_size, "offset": offset})
    users = result.fetchall()

    count_result = await db.execute(count_query, params)
    total = count_result.scalar()

    return {
        "total": total,
        "page": page,
        "page_size": page_size,
        "data": [
            {
                "id": u[0],
                "username": u[1],
                "email": u[2],
                "role": u[3],
                "is_active": u[4],
                "last_login_at": u[5],
                "created_at": u[6],
            }
            for u in users
        ],
    }
|
||||
|
||||
|
||||
@router.get("/{user_id}", response_model=dict)
async def get_user(
    user_id: int,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Fetch one user; admins can read anyone, others only themselves."""
    allowed = check_permission(current_user, ["super_admin", "admin"]) or (
        current_user.id == user_id
    )
    if not allowed:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Insufficient permissions",
        )

    row = (
        await db.execute(
            text(
                "SELECT id, username, email, role, is_active, last_login_at, created_at FROM users WHERE id = :id"
            ),
            {"id": user_id},
        )
    ).fetchone()

    if row is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="User not found",
        )

    return {
        "id": row[0],
        "username": row[1],
        "email": row[2],
        "role": row[3],
        "is_active": row[4],
        "last_login_at": row[5],
        "created_at": row[6],
    }
|
||||
|
||||
|
||||
@router.post("", response_model=dict, status_code=status.HTTP_201_CREATED)
|
||||
async def create_user(
|
||||
user_data: UserCreate,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
if not check_permission(current_user, ["super_admin"]):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Only super_admin can create users",
|
||||
)
|
||||
|
||||
result = await db.execute(
|
||||
text("SELECT id FROM users WHERE username = :username OR email = :email"),
|
||||
{"username": user_data.username, "email": user_data.email},
|
||||
)
|
||||
if result.fetchone():
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Username or email already exists",
|
||||
)
|
||||
|
||||
hashed_password = get_password_hash(user_data.password)
|
||||
|
||||
await db.execute(
|
||||
text("""INSERT INTO users (username, email, password_hash, role, is_active, created_at, updated_at)
|
||||
VALUES (:username, :email, :password_hash, :role, :is_active, NOW(), NOW())"""),
|
||||
{
|
||||
"username": user_data.username,
|
||||
"email": user_data.email,
|
||||
"password_hash": hashed_password,
|
||||
"role": user_data.role,
|
||||
"is_active": True,
|
||||
},
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
# Get the inserted user ID
|
||||
result = await db.execute(
|
||||
text("SELECT id FROM users WHERE username = :username"),
|
||||
{"username": user_data.username},
|
||||
)
|
||||
new_user = result.fetchone()
|
||||
|
||||
if new_user is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to create user",
|
||||
)
|
||||
|
||||
return {
|
||||
"id": new_user[0],
|
||||
"username": user_data.username,
|
||||
"email": user_data.email,
|
||||
"role": user_data.role,
|
||||
"is_active": True,
|
||||
}
|
||||
|
||||
|
||||
@router.put("/{user_id}")
|
||||
async def update_user(
|
||||
user_id: int,
|
||||
user_data: UserUpdate,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
if not check_permission(current_user, ["super_admin", "admin"]) and current_user.id != user_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Insufficient permissions",
|
||||
)
|
||||
|
||||
if not check_permission(current_user, ["super_admin"]) and user_data.role is not None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Only super_admin can change user role",
|
||||
)
|
||||
|
||||
result = await db.execute(
|
||||
text("SELECT id FROM users WHERE id = :id"),
|
||||
{"id": user_id},
|
||||
)
|
||||
if not result.fetchone():
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User not found",
|
||||
)
|
||||
|
||||
update_fields = []
|
||||
params = {"id": user_id}
|
||||
if user_data.email is not None:
|
||||
update_fields.append("email = :email")
|
||||
params["email"] = user_data.email
|
||||
if user_data.role is not None:
|
||||
update_fields.append("role = :role")
|
||||
params["role"] = user_data.role
|
||||
if user_data.is_active is not None:
|
||||
update_fields.append("is_active = :is_active")
|
||||
params["is_active"] = user_data.is_active
|
||||
|
||||
if update_fields:
|
||||
update_fields.append("updated_at = NOW()")
|
||||
query = text(f"UPDATE users SET {', '.join(update_fields)} WHERE id = :id")
|
||||
await db.execute(query, params)
|
||||
await db.commit()
|
||||
|
||||
return {"message": "User updated successfully"}
|
||||
|
||||
|
||||
@router.delete("/{user_id}")
|
||||
async def delete_user(
|
||||
user_id: int,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
if not check_permission(current_user, ["super_admin"]):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Only super_admin can delete users",
|
||||
)
|
||||
|
||||
if current_user.id == user_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Cannot delete yourself",
|
||||
)
|
||||
|
||||
result = await db.execute(
|
||||
text("SELECT id FROM users WHERE id = :id"),
|
||||
{"id": user_id},
|
||||
)
|
||||
if not result.fetchone():
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User not found",
|
||||
)
|
||||
|
||||
await db.execute(
|
||||
text("DELETE FROM users WHERE id = :id"),
|
||||
{"id": user_id},
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
return {"message": "User deleted successfully"}
|
||||
99
backend/app/api/v1/websocket.py
Normal file
99
backend/app/api/v1/websocket.py
Normal file
@@ -0,0 +1,99 @@
|
||||
"""WebSocket API endpoints"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query
|
||||
from jose import jwt, JWTError
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.websocket.manager import manager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
async def authenticate_token(token: str) -> Optional[dict]:
    """Authenticate a WebSocket connection via its JWT.

    Returns the decoded payload when the token is a valid *access* token,
    otherwise None (invalid signature, expired, or wrong token type).
    """
    try:
        payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
    except JWTError as e:
        logger.warning(f"WebSocket auth failed: {e}")
        return None
    if payload.get("type") != "access":
        # Plain string literal: the original used an f-string with no
        # placeholders (same output, just the useless f prefix removed).
        logger.warning("WebSocket auth failed: wrong token type")
        return None
    return payload
|
||||
|
||||
|
||||
@router.websocket("/ws")
|
||||
async def websocket_endpoint(
|
||||
websocket: WebSocket,
|
||||
token: str = Query(...),
|
||||
):
|
||||
"""WebSocket endpoint for real-time data"""
|
||||
logger.info(f"WebSocket connection attempt with token: {token[:20]}...")
|
||||
payload = await authenticate_token(token)
|
||||
if payload is None:
|
||||
logger.warning("WebSocket authentication failed, closing connection")
|
||||
await websocket.close(code=4001)
|
||||
return
|
||||
|
||||
user_id = str(payload.get("sub"))
|
||||
await manager.connect(websocket, user_id)
|
||||
|
||||
try:
|
||||
await websocket.send_json(
|
||||
{
|
||||
"type": "connection_established",
|
||||
"data": {
|
||||
"connection_id": f"conn_{user_id}",
|
||||
"server_version": settings.VERSION,
|
||||
"heartbeat_interval": 30,
|
||||
"supported_channels": [
|
||||
"gpu_clusters",
|
||||
"submarine_cables",
|
||||
"ixp_nodes",
|
||||
"alerts",
|
||||
"dashboard",
|
||||
],
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
while True:
|
||||
try:
|
||||
data = await asyncio.wait_for(websocket.receive_json(), timeout=30)
|
||||
|
||||
if data.get("type") == "heartbeat":
|
||||
await websocket.send_json(
|
||||
{
|
||||
"type": "heartbeat",
|
||||
"data": {"action": "pong", "timestamp": datetime.utcnow().isoformat()},
|
||||
}
|
||||
)
|
||||
elif data.get("type") == "subscribe":
|
||||
channels = data.get("data", {}).get("channels", [])
|
||||
await websocket.send_json(
|
||||
{
|
||||
"type": "subscription_confirmed",
|
||||
"data": {"action": "subscribe", "channels": channels},
|
||||
}
|
||||
)
|
||||
elif data.get("type") == "control_frame":
|
||||
await websocket.send_json(
|
||||
{"type": "control_acknowledged", "data": {"received": True}}
|
||||
)
|
||||
else:
|
||||
await websocket.send_json({"type": "ack", "data": {"received": True}})
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
await websocket.send_json({"type": "heartbeat", "data": {"action": "ping"}})
|
||||
|
||||
except WebSocketDisconnect:
|
||||
pass
|
||||
finally:
|
||||
manager.disconnect(websocket, user_id)
|
||||
BIN
backend/app/core/__pycache__/cache.cpython-311.pyc
Normal file
BIN
backend/app/core/__pycache__/cache.cpython-311.pyc
Normal file
Binary file not shown.
BIN
backend/app/core/__pycache__/config.cpython-311.pyc
Normal file
BIN
backend/app/core/__pycache__/config.cpython-311.pyc
Normal file
Binary file not shown.
BIN
backend/app/core/__pycache__/security.cpython-311.pyc
Normal file
BIN
backend/app/core/__pycache__/security.cpython-311.pyc
Normal file
Binary file not shown.
128
backend/app/core/cache.py
Normal file
128
backend/app/core/cache.py
Normal file
@@ -0,0 +1,128 @@
|
||||
"""Redis caching service"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from datetime import timedelta
|
||||
from typing import Optional, Any
|
||||
|
||||
import redis
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Lazy Redis client initialization
|
||||
class _RedisClient:
    """Process-wide, lazily constructed Redis client holder."""

    _client = None

    @classmethod
    def get_client(cls):
        """Return the shared client, creating it on first use."""
        if cls._client is not None:
            return cls._client

        url = settings.REDIS_URL
        if url.startswith("redis://"):
            client = redis.from_url(url, decode_responses=True)
        else:
            client = redis.Redis(
                host=settings.REDIS_SERVER,
                port=settings.REDIS_PORT,
                db=settings.REDIS_DB,
                decode_responses=True,
            )
        cls._client = client
        return cls._client
|
||||
|
||||
|
||||
class CacheService:
    """Redis caching service with JSON serialization.

    All Redis errors are logged and swallowed so a cache outage degrades to
    a cache miss instead of failing the request.
    """

    def __init__(self):
        # Shared, lazily-created client (see _RedisClient above).
        self.client = _RedisClient.get_client()

    def get(self, key: str) -> Optional[Any]:
        """Return the JSON-decoded value for *key*, or None on miss/error."""
        try:
            value = self.client.get(key)
            if value:
                return json.loads(value)
            return None
        except Exception as e:
            logger.warning(f"Cache get error: {e}")
            return None

    def set(
        self,
        key: str,
        value: Any,
        expire_seconds: int = 300,
    ) -> bool:
        """Store *value* under *key* with a TTL; return True on success.

        Non-JSON-native values are stringified via default=str.
        """
        try:
            serialized = json.dumps(value, default=str)
            return self.client.setex(key, expire_seconds, serialized)
        except Exception as e:
            logger.warning(f"Cache set error: {e}")
            return False

    def delete(self, key: str) -> bool:
        """Delete *key*; return True if a key was actually removed."""
        try:
            return self.client.delete(key) > 0
        except Exception as e:
            logger.warning(f"Cache delete error: {e}")
            return False

    def delete_pattern(self, pattern: str) -> int:
        """Delete all keys matching *pattern*; return the number deleted."""
        try:
            # SCAN instead of KEYS: KEYS is O(N) over the whole keyspace and
            # blocks the Redis server; scan_iter walks it incrementally.
            keys = list(self.client.scan_iter(match=pattern))
            if keys:
                return self.client.delete(*keys)
            return 0
        except Exception as e:
            logger.warning(f"Cache delete_pattern error: {e}")
            return 0

    def get_or_set(
        self,
        key: str,
        fallback: callable,
        expire_seconds: int = 300,
    ) -> Optional[Any]:
        """Return the cached value, computing and caching it via *fallback* on miss."""
        value = self.get(key)
        if value is not None:
            return value

        value = fallback()
        if value is not None:
            self.set(key, value, expire_seconds)
        return value

    def invalidate_pattern(self, pattern: str) -> int:
        """Invalidate all keys matching pattern (alias for delete_pattern)."""
        return self.delete_pattern(pattern)
|
||||
|
||||
|
||||
cache = CacheService()
|
||||
|
||||
|
||||
def cached(expire_seconds: int = 300, key_prefix: str = ""):
    """Decorator caching async function results in Redis.

    The cache key is derived from the function name and the repr of its
    arguments, so arguments must have stable reprs.
    """
    from functools import wraps  # local import: module header not editable here

    def decorator(func):
        # wraps preserves __name__/__doc__ — FastAPI's route introspection
        # and debugging rely on them; without it every cached endpoint is
        # reported as "wrapper".
        @wraps(func)
        async def wrapper(*args, **kwargs):
            cache_key = f"{key_prefix}:{func.__name__}:{args}:{kwargs}"
            cache_key = cache_key.replace(":", "_").replace(" ", "")

            cached_value = cache.get(cache_key)
            if cached_value is not None:
                return cached_value

            result = await func(*args, **kwargs)
            cache.set(cache_key, result, expire_seconds)
            return result

        return wrapper

    return decorator
|
||||
46
backend/app/core/config.py
Normal file
46
backend/app/core/config.py
Normal file
@@ -0,0 +1,46 @@
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
import os
|
||||
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class Settings(BaseSettings):
    """Application configuration; every field is overridable via the
    environment or the .env file (see Config below)."""

    # --- Project metadata / API ---
    PROJECT_NAME: str = "Intelligent Planet Plan"
    VERSION: str = "1.0.0"
    API_V1_STR: str = "/api/v1"
    # --- Auth / JWT ---
    SECRET_KEY: str = "your-secret-key-change-in-production"  # must be overridden in production
    ALGORITHM: str = "HS256"
    ACCESS_TOKEN_EXPIRE_MINUTES: int = 15
    REFRESH_TOKEN_EXPIRE_DAYS: int = 7

    # --- PostgreSQL ---
    POSTGRES_SERVER: str = "localhost"
    POSTGRES_USER: str = "postgres"
    POSTGRES_PASSWORD: str = "postgres"
    POSTGRES_DB: str = "planet_db"
    # NOTE(review): this default hard-codes host and credentials
    # ("postgres:postgres@postgres") and ignores the POSTGRES_* fields above
    # — confirm deployments always override DATABASE_URL via the environment.
    DATABASE_URL: str = f"postgresql+asyncpg://postgres:postgres@postgres:5432/planet_db"

    # --- Redis ---
    REDIS_SERVER: str = "localhost"
    REDIS_PORT: int = 6379
    REDIS_DB: int = 0

    # Origins allowed by the CORS middleware.
    CORS_ORIGINS: List[str] = ["http://localhost:3000", "http://localhost:8000"]

    @property
    def REDIS_URL(self) -> str:
        # Explicit REDIS_URL env var wins; otherwise built from the parts above.
        return os.getenv(
            "REDIS_URL", f"redis://{self.REDIS_SERVER}:{self.REDIS_PORT}/{self.REDIS_DB}"
        )

    class Config:
        # pydantic-settings: load overrides from .env, keep exact-case names.
        env_file = ".env"
        case_sensitive = True
|
||||
|
||||
|
||||
@lru_cache()
def get_settings() -> Settings:
    """Return the process-wide Settings singleton (cached after first call)."""
    return Settings()
|
||||
|
||||
|
||||
settings = get_settings()
|
||||
162
backend/app/core/security.py
Normal file
162
backend/app/core/security.py
Normal file
@@ -0,0 +1,162 @@
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional
|
||||
|
||||
import bcrypt
|
||||
import redis
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||
from jose import JWTError, jwt
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.config import settings
|
||||
from app.db.session import get_db
|
||||
from app.models.user import User
|
||||
|
||||
oauth2_scheme = HTTPBearer()
|
||||
|
||||
|
||||
class _RedisClient:
    """Lazily-initialized, module-wide Redis client."""

    _client = None

    @classmethod
    def get_client(cls):
        """Create the client on first call; reuse it afterwards."""
        if cls._client is None:
            url = settings.REDIS_URL
            if not url.startswith("redis://"):
                cls._client = redis.Redis(
                    host=settings.REDIS_SERVER,
                    port=settings.REDIS_PORT,
                    db=settings.REDIS_DB,
                    decode_responses=True,
                )
            else:
                cls._client = redis.from_url(url, decode_responses=True)
        return cls._client
|
||||
|
||||
|
||||
redis_client = _RedisClient.get_client()
|
||||
|
||||
|
||||
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Return True when *plain_password* matches the stored bcrypt hash."""
    return bcrypt.checkpw(plain_password.encode(), hashed_password.encode())
|
||||
|
||||
|
||||
def get_password_hash(password: str) -> str:
    """Hash *password* with bcrypt (fresh salt per call) and return it as str."""
    return bcrypt.hashpw(password.encode(), bcrypt.gensalt()).decode()
|
||||
|
||||
|
||||
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
    """Build a signed JWT access token carrying *data* as claims.

    Adds exp (now + expires_delta, defaulting to the configured access-token
    lifetime) and type="access"; stringifies "sub" when present.
    """
    lifetime = expires_delta or timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
    claims = dict(data)
    claims["exp"] = datetime.utcnow() + lifetime
    claims["type"] = "access"
    if "sub" in claims:
        claims["sub"] = str(claims["sub"])
    return jwt.encode(claims, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
|
||||
|
||||
|
||||
def create_refresh_token(data: dict) -> str:
    """Build a signed JWT refresh token carrying *data* as claims.

    Adds exp (now + configured refresh lifetime) and type="refresh";
    stringifies "sub" when present.
    """
    claims = dict(data)
    claims["exp"] = datetime.utcnow() + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS)
    claims["type"] = "refresh"
    if "sub" in claims:
        claims["sub"] = str(claims["sub"])
    return jwt.encode(claims, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
|
||||
|
||||
|
||||
def decode_token(token: str) -> Optional[dict]:
    """Decode and verify a JWT; return its payload, or None when invalid/expired."""
    try:
        return jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
    except JWTError:
        return None
|
||||
|
||||
|
||||
async def get_current_user(
    credentials: HTTPAuthorizationCredentials = Depends(oauth2_scheme),
    db: AsyncSession = Depends(get_db),
) -> User:
    """FastAPI dependency resolving the bearer token to an active User.

    Raises 401 when the token is revoked, invalid, not an access token,
    or the user no longer exists / is inactive.
    """
    token = credentials.credentials

    def _unauthorized(reason: str) -> HTTPException:
        return HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=reason)

    # Revoked tokens are kept in a Redis set by blacklist_token().
    if redis_client.sismember("blacklisted_tokens", token):
        raise _unauthorized("Token has been revoked")

    payload = decode_token(token)
    if payload is None or payload.get("type") != "access":
        raise _unauthorized("Invalid token")

    user_id = payload.get("sub")
    if user_id is None:
        raise _unauthorized("Invalid token")

    row = (
        await db.execute(
            text(
                "SELECT id, username, email, password_hash, role, is_active FROM users WHERE id = :id"
            ),
            {"id": int(user_id)},
        )
    ).fetchone()
    # row[5] is is_active — reject missing or deactivated accounts.
    if row is None or not row[5]:
        raise _unauthorized("User not found or inactive")

    user = User()
    user.id, user.username, user.email, user.password_hash, user.role, user.is_active = row
    return user
|
||||
|
||||
|
||||
async def get_current_user_refresh(
    credentials: HTTPAuthorizationCredentials = Depends(oauth2_scheme),
    db: AsyncSession = Depends(get_db),
) -> User:
    """FastAPI dependency resolving a *refresh* token to its active User.

    Unlike get_current_user, the token is not checked against the
    blacklist here.  Raises 401 on any failure.
    """
    payload = decode_token(credentials.credentials)
    if payload is None or payload.get("type") != "refresh":
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid refresh token",
        )

    user_id = payload.get("sub")
    if user_id is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid token",
        )

    row = (
        await db.execute(
            text(
                "SELECT id, username, email, password_hash, role, is_active FROM users WHERE id = :id"
            ),
            {"id": int(user_id)},
        )
    ).fetchone()
    # row[5] is is_active — reject missing or deactivated accounts.
    if row is None or not row[5]:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="User not found or inactive",
        )

    user = User()
    for field, value in zip(
        ("id", "username", "email", "password_hash", "role", "is_active"), row
    ):
        setattr(user, field, value)
    return user
|
||||
|
||||
|
||||
def blacklist_token(token: str) -> None:
    """Revoke *token* by adding it to the Redis set get_current_user checks.

    NOTE(review): the set grows without bound — entries never expire even
    after the token itself does.  Consider storing each token under its own
    key with a TTL equal to the token's remaining lifetime.
    """
    redis_client.sadd("blacklisted_tokens", token)
|
||||
4
backend/app/core/websocket/__init__.py
Normal file
4
backend/app/core/websocket/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
"""__init__.py for websocket package"""
|
||||
|
||||
from app.core.websocket.manager import manager, ConnectionManager
|
||||
from app.core.websocket.broadcaster import broadcaster, DataBroadcaster
|
||||
BIN
backend/app/core/websocket/__pycache__/__init__.cpython-311.pyc
Normal file
BIN
backend/app/core/websocket/__pycache__/__init__.cpython-311.pyc
Normal file
Binary file not shown.
Binary file not shown.
BIN
backend/app/core/websocket/__pycache__/manager.cpython-311.pyc
Normal file
BIN
backend/app/core/websocket/__pycache__/manager.cpython-311.pyc
Normal file
Binary file not shown.
93
backend/app/core/websocket/broadcaster.py
Normal file
93
backend/app/core/websocket/broadcaster.py
Normal file
@@ -0,0 +1,93 @@
|
||||
"""Data broadcaster for WebSocket connections"""
|
||||
|
||||
import asyncio
import logging
from datetime import datetime
from typing import Dict, Any, Optional

from app.core.websocket.manager import manager
|
||||
|
||||
|
||||
class DataBroadcaster:
    """Periodically broadcasts data to connected WebSocket clients."""

    def __init__(self):
        self.running = False  # loop guard for broadcast_stats
        self.tasks: Dict[str, asyncio.Task] = {}  # task name -> background task

    async def get_dashboard_stats(self) -> Dict[str, Any]:
        """Get dashboard statistics.

        NOTE(review): values are hard-coded placeholders — confirm they get
        wired to real collectors.
        """
        return {
            "total_datasources": 9,
            "active_datasources": 8,
            "tasks_today": 45,
            "success_rate": 97.8,
            "last_updated": datetime.utcnow().isoformat(),
            "alerts": {"critical": 0, "warning": 2, "info": 5},
        }

    async def broadcast_stats(self, interval: int = 5):
        """Broadcast dashboard stats every *interval* seconds while running."""
        while self.running:
            try:
                stats = await self.get_dashboard_stats()
                await manager.broadcast(
                    {
                        "type": "data_frame",
                        "channel": "dashboard",
                        "timestamp": datetime.utcnow().isoformat(),
                        "payload": {"stats": stats},
                    },
                    channel="dashboard",
                )
            except Exception:
                # Best-effort loop: keep broadcasting, but record the failure
                # instead of silently swallowing it.
                logging.getLogger(__name__).exception("dashboard stats broadcast failed")
            await asyncio.sleep(interval)

    async def broadcast_alert(self, alert: Dict[str, Any]):
        """Broadcast an alert to all connected clients."""
        await manager.broadcast(
            {
                "type": "alert_notification",
                "timestamp": datetime.utcnow().isoformat(),
                "data": {"alert": alert},
            }
        )

    async def broadcast_gpu_update(self, data: Dict[str, Any]):
        """Broadcast GPU cluster update."""
        await manager.broadcast(
            {
                "type": "data_frame",
                "channel": "gpu_clusters",
                "timestamp": datetime.utcnow().isoformat(),
                "payload": data,
            }
        )

    async def broadcast_custom(self, channel: str, data: Dict[str, Any]):
        """Broadcast custom data to a specific channel.

        Falls back to broadcasting to everyone when *channel* is not a known
        connection key in the manager.
        """
        await manager.broadcast(
            {
                "type": "data_frame",
                "channel": channel,
                "timestamp": datetime.utcnow().isoformat(),
                "payload": data,
            },
            channel=channel if channel in manager.active_connections else "all",
        )

    def start(self):
        """Start all broadcasters (idempotent)."""
        if not self.running:
            self.running = True
            self.tasks["dashboard"] = asyncio.create_task(self.broadcast_stats(5))

    def stop(self):
        """Stop all broadcasters and cancel their tasks."""
        self.running = False
        for task in self.tasks.values():
            task.cancel()
        self.tasks.clear()
|
||||
|
||||
|
||||
broadcaster = DataBroadcaster()
|
||||
70
backend/app/core/websocket/manager.py
Normal file
70
backend/app/core/websocket/manager.py
Normal file
@@ -0,0 +1,70 @@
|
||||
"""WebSocket Connection Manager"""
|
||||
|
||||
import json
|
||||
import asyncio
|
||||
from typing import Dict, Set, Optional
|
||||
from datetime import datetime
|
||||
from fastapi import WebSocket
|
||||
import redis.asyncio as redis
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
|
||||
class ConnectionManager:
    """Tracks live WebSocket connections, grouped by user id."""

    def __init__(self):
        # user_id -> set of that user's open sockets
        self.active_connections: Dict[str, Set[WebSocket]] = {}
        # Lazily created on first connect; currently only instantiated here.
        self.redis_client: Optional[redis.Redis] = None

    async def connect(self, websocket: WebSocket, user_id: str):
        """Accept *websocket* and register it under *user_id*."""
        await websocket.accept()
        self.active_connections.setdefault(user_id, set()).add(websocket)

        if self.redis_client is None:
            url = settings.REDIS_URL
            if url.startswith("redis://"):
                self.redis_client = redis.from_url(url, decode_responses=True)
            else:
                self.redis_client = redis.Redis(
                    host=settings.REDIS_SERVER,
                    port=settings.REDIS_PORT,
                    db=settings.REDIS_DB,
                    decode_responses=True,
                )

    def disconnect(self, websocket: WebSocket, user_id: str):
        """Unregister *websocket*; drop the user entry when it was the last socket."""
        sockets = self.active_connections.get(user_id)
        if sockets is None:
            return
        sockets.discard(websocket)
        if not sockets:
            del self.active_connections[user_id]

    async def send_personal_message(self, message: dict, user_id: str):
        """Send *message* to every socket of *user_id*, ignoring send failures."""
        for connection in self.active_connections.get(user_id, ()):
            try:
                await connection.send_json(message)
            except Exception:
                pass

    async def broadcast(self, message: dict, channel: str = "all"):
        """Send to everyone ("all") or to the single user id named by *channel*."""
        targets = list(self.active_connections) if channel == "all" else [channel]
        for target in targets:
            await self.send_personal_message(message, target)

    async def close_all(self):
        """Close every tracked socket and clear the registry."""
        for sockets in self.active_connections.values():
            for connection in sockets:
                await connection.close()
        self.active_connections.clear()
|
||||
|
||||
|
||||
manager = ConnectionManager()
|
||||
|
||||
|
||||
async def get_websocket_manager() -> ConnectionManager:
    """FastAPI dependency returning the process-wide connection manager."""
    return manager
|
||||
BIN
backend/app/db/__pycache__/session.cpython-311.pyc
Normal file
BIN
backend/app/db/__pycache__/session.cpython-311.pyc
Normal file
Binary file not shown.
35
backend/app/db/session.py
Normal file
35
backend/app/db/session.py
Normal file
@@ -0,0 +1,35 @@
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
|
||||
from sqlalchemy.orm import declarative_base
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
# Async engine; SQL statement echo only when the settings object defines a
# truthy DEBUG attribute.
engine = create_async_engine(
    settings.DATABASE_URL,
    echo=getattr(settings, "DEBUG", False),
)

# Session factory; expire_on_commit=False keeps loaded objects usable after
# the request-level commit in get_db().
async_session_factory = async_sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)

# Declarative base all ORM models inherit from.
Base = declarative_base()
|
||||
|
||||
|
||||
async def get_db() -> AsyncGenerator[AsyncSession, None]:
    """Yield a request-scoped AsyncSession.

    Commits after the request body finishes without error (the commit is
    inside the try, so a failed commit is also rolled back); rolls back and
    re-raises on any exception.  The context manager closes the session.
    """
    async with async_session_factory() as session:
        try:
            yield session
            await session.commit()
        except Exception:
            await session.rollback()
            raise
|
||||
|
||||
|
||||
async def init_db():
    """Create all tables for the registered models (startup bootstrap).

    The model modules are imported for their side effect of registering
    tables on Base.metadata before create_all runs.
    """
    import app.models.user  # noqa: F401
    import app.models.gpu_cluster  # noqa: F401
    import app.models.task  # noqa: F401
    import app.models.datasource  # noqa: F401

    async with engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)
|
||||
86
backend/app/main.py
Normal file
86
backend/app/main.py
Normal file
@@ -0,0 +1,86 @@
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.websocket.broadcaster import broadcaster
|
||||
from app.db.session import init_db, async_session_factory
|
||||
from app.api.main import api_router
|
||||
from app.api.v1 import websocket
|
||||
from app.services.scheduler import start_scheduler, stop_scheduler
|
||||
|
||||
|
||||
class WebSocketCORSMiddleware(BaseHTTPMiddleware):
    """Adds permissive CORS headers to GET responses under the /ws path."""

    async def dispatch(self, request, call_next):
        response = await call_next(request)
        if request.url.path.startswith("/ws") and request.method == "GET":
            response.headers["Access-Control-Allow-Origin"] = "*"
            response.headers["Access-Control-Allow-Methods"] = "GET, OPTIONS"
            response.headers["Access-Control-Allow-Headers"] = "*"
        return response
|
||||
|
||||
|
||||
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: initialize on startup, tear down on shutdown.

    try/finally guarantees the broadcaster and scheduler are stopped even
    when the application body exits with an exception.
    """
    await init_db()
    start_scheduler()
    broadcaster.start()
    try:
        yield
    finally:
        broadcaster.stop()
        stop_scheduler()
|
||||
|
||||
|
||||
# Application instance; OpenAPI docs are served by ReDoc at /docs.
app = FastAPI(
    title=settings.PROJECT_NAME,
    version=settings.VERSION,
    description="智能星球计划 - 态势感知系统\n\n## 功能模块\n\n- **用户认证**: JWT-based authentication\n- **数据源管理**: 多源数据采集器管理\n- **任务调度**: 定时任务调度与监控\n- **实时更新**: WebSocket实时数据推送\n- **告警系统**: 多级告警管理\n\n## 数据层级\n\n- **L1**: 核心数据 (TOP500, Epoch AI GPU)\n- **L2**: 扩展数据 (HuggingFace, PeeringDB, 海缆)\n- **L3**: 分析数据\n- **L4**: 决策支持",
    lifespan=lifespan,
    docs_url=None,
    redoc_url="/docs",
    openapi_url="/openapi.json",
)

# allow_origins=["*"] together with allow_credentials=True is forbidden by
# the CORS spec, and Starlette will not echo the Origin header in that
# configuration — credentialed browser requests would fail.  Use the explicit
# whitelist already defined in settings.
app.add_middleware(
    CORSMiddleware,
    allow_origins=settings.CORS_ORIGINS,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

app.add_middleware(WebSocketCORSMiddleware)

# Versioned API prefix comes from settings for consistency with the rest of
# the configuration.
app.include_router(api_router, prefix=settings.API_V1_STR)
app.include_router(websocket.router)
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health_check():
|
||||
"""健康检查端点"""
|
||||
return {
|
||||
"status": "healthy",
|
||||
"version": settings.VERSION,
|
||||
}
|
||||
|
||||
|
||||
@app.get("/")
|
||||
async def root():
|
||||
"""API根目录"""
|
||||
return {
|
||||
"name": settings.PROJECT_NAME,
|
||||
"version": settings.VERSION,
|
||||
"docs": "/docs",
|
||||
"redoc": "/redoc",
|
||||
}
|
||||
|
||||
|
||||
@app.get("/api/v1/scheduler/jobs")
|
||||
async def get_scheduler_jobs():
|
||||
"""获取调度任务列表"""
|
||||
from app.services.scheduler import get_scheduler_jobs
|
||||
|
||||
return {"jobs": get_scheduler_jobs()}
|
||||
15
backend/app/models/__init__.py
Normal file
15
backend/app/models/__init__.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from app.models.user import User
|
||||
from app.models.gpu_cluster import GPUCluster
|
||||
from app.models.task import CollectionTask
|
||||
from app.models.datasource import DataSource
|
||||
from app.models.alert import Alert, AlertSeverity, AlertStatus
|
||||
|
||||
__all__ = [
|
||||
"User",
|
||||
"GPUCluster",
|
||||
"CollectionTask",
|
||||
"DataSource",
|
||||
"Alert",
|
||||
"AlertSeverity",
|
||||
"AlertStatus",
|
||||
]
|
||||
BIN
backend/app/models/__pycache__/__init__.cpython-311.pyc
Normal file
BIN
backend/app/models/__pycache__/__init__.cpython-311.pyc
Normal file
Binary file not shown.
BIN
backend/app/models/__pycache__/alert.cpython-311.pyc
Normal file
BIN
backend/app/models/__pycache__/alert.cpython-311.pyc
Normal file
Binary file not shown.
BIN
backend/app/models/__pycache__/collected_data.cpython-311.pyc
Normal file
BIN
backend/app/models/__pycache__/collected_data.cpython-311.pyc
Normal file
Binary file not shown.
BIN
backend/app/models/__pycache__/datasource.cpython-311.pyc
Normal file
BIN
backend/app/models/__pycache__/datasource.cpython-311.pyc
Normal file
Binary file not shown.
BIN
backend/app/models/__pycache__/datasource_config.cpython-311.pyc
Normal file
BIN
backend/app/models/__pycache__/datasource_config.cpython-311.pyc
Normal file
Binary file not shown.
BIN
backend/app/models/__pycache__/gpu_cluster.cpython-311.pyc
Normal file
BIN
backend/app/models/__pycache__/gpu_cluster.cpython-311.pyc
Normal file
Binary file not shown.
BIN
backend/app/models/__pycache__/task.cpython-311.pyc
Normal file
BIN
backend/app/models/__pycache__/task.cpython-311.pyc
Normal file
Binary file not shown.
BIN
backend/app/models/__pycache__/user.cpython-311.pyc
Normal file
BIN
backend/app/models/__pycache__/user.cpython-311.pyc
Normal file
Binary file not shown.
57
backend/app/models/alert.py
Normal file
57
backend/app/models/alert.py
Normal file
@@ -0,0 +1,57 @@
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import Column, Integer, String, DateTime, Text, ForeignKey, Enum as SQLEnum
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
from app.db.session import Base
|
||||
|
||||
|
||||
class AlertSeverity(str, Enum):
    """Alert severity levels; str-valued so members serialize cleanly to JSON/DB."""

    CRITICAL = "critical"
    WARNING = "warning"
    INFO = "info"
|
||||
|
||||
|
||||
class AlertStatus(str, Enum):
    """Alert lifecycle states; str-valued so members serialize cleanly to JSON/DB."""

    ACTIVE = "active"
    ACKNOWLEDGED = "acknowledged"
    RESOLVED = "resolved"
|
||||
|
||||
|
||||
class Alert(Base):
    """ORM model for alerts raised against data sources."""

    __tablename__ = "alerts"

    id = Column(Integer, primary_key=True, index=True)
    severity = Column(SQLEnum(AlertSeverity), default=AlertSeverity.WARNING)
    status = Column(SQLEnum(AlertStatus), default=AlertStatus.ACTIVE)
    # Loose reference to the originating data source (no ForeignKey constraint).
    datasource_id = Column(Integer, nullable=True, index=True)
    datasource_name = Column(String(255), nullable=True)
    message = Column(Text)
    # Free-form serialized context; named alert_metadata because `metadata`
    # is reserved on SQLAlchemy declarative classes.
    alert_metadata = Column(Text, nullable=True)
    acknowledged_by = Column(Integer, nullable=True)  # presumably a user id — confirm against callers
    resolved_by = Column(Integer, nullable=True)  # presumably a user id — confirm against callers
    resolution_notes = Column(Text, nullable=True)
    # NOTE(review): naive datetimes via utcnow (deprecated in Python 3.12+);
    # consistent with the rest of this model, so left unchanged here.
    created_at = Column(DateTime, default=datetime.utcnow)
    updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
    acknowledged_at = Column(DateTime, nullable=True)
    resolved_at = Column(DateTime, nullable=True)

    def to_dict(self):
        """Serialize to a JSON-friendly dict (enums -> values, datetimes -> ISO 8601)."""
        return {
            "id": self.id,
            "severity": self.severity.value if self.severity else None,
            "status": self.status.value if self.status else None,
            "datasource_id": self.datasource_id,
            "datasource_name": self.datasource_name,
            "message": self.message,
            "alert_metadata": self.alert_metadata,
            "acknowledged_by": self.acknowledged_by,
            "resolved_by": self.resolved_by,
            "resolution_notes": self.resolution_notes,
            "created_at": self.created_at.isoformat() if self.created_at else None,
            "updated_at": self.updated_at.isoformat() if self.updated_at else None,
            "acknowledged_at": self.acknowledged_at.isoformat() if self.acknowledged_at else None,
            "resolved_at": self.resolved_at.isoformat() if self.resolved_at else None,
        }
|
||||
80
backend/app/models/collected_data.py
Normal file
80
backend/app/models/collected_data.py
Normal file
@@ -0,0 +1,80 @@
|
||||
"""Collected Data model for storing data from all collectors"""
|
||||
|
||||
from sqlalchemy import Column, DateTime, Integer, String, Text, JSON, Index
|
||||
from sqlalchemy.sql import func
|
||||
|
||||
from app.db.session import Base
|
||||
|
||||
|
||||
class CollectedData(Base):
    """Generic model for storing collected data from all sources.

    Every collector normalizes its records into this one table; `source` and
    `data_type` identify where a row came from and what it represents.
    """

    __tablename__ = "collected_data"

    id = Column(Integer, primary_key=True, autoincrement=True)
    source = Column(String(100), nullable=False, index=True)  # e.g., "top500", "huggingface_models"
    source_id = Column(String(100), index=True)  # Original ID from source, e.g., "rank_1"
    data_type = Column(
        String(50), nullable=False, index=True
    )  # e.g., "supercomputer", "model", "dataset"

    # Core data fields
    name = Column(String(500))
    title = Column(String(500))
    description = Column(Text)

    # Location data (for geo visualization); coordinates stored as strings.
    country = Column(String(100))
    city = Column(String(100))
    latitude = Column(String(50))
    longitude = Column(String(50))

    # Performance metrics
    value = Column(String(100))  # Generic value field (Rmax, Rpeak, etc.)
    unit = Column(String(20))

    # Additional metadata as JSON. Attribute is 'extra_data' because 'metadata'
    # is reserved on declarative classes; the DB column is still 'metadata'.
    # default=dict (a callable) so every row gets its own fresh dict — a
    # literal {} default would be one shared mutable object across all rows.
    extra_data = Column("metadata", JSON, default=dict)

    # Timestamps
    collected_at = Column(DateTime(timezone=True), server_default=func.now(), index=True)
    reference_date = Column(DateTime(timezone=True))  # Data reference date (e.g., TOP500 list date)

    # Status
    is_valid = Column(Integer, default=1)  # 1=valid, 0=invalid

    # Composite indexes for the common "per source over time / per type" queries
    __table_args__ = (
        Index("idx_collected_data_source_collected", "source", "collected_at"),
        Index("idx_collected_data_source_type", "source", "data_type"),
    )

    def __repr__(self):
        return f"<CollectedData {self.id}: {self.source}/{self.data_type}>"

    def to_dict(self) -> dict:
        """Serialize to a JSON-friendly dict (datetimes -> ISO 8601, or None)."""
        return {
            "id": self.id,
            "source": self.source,
            "source_id": self.source_id,
            "data_type": self.data_type,
            "name": self.name,
            "title": self.title,
            "description": self.description,
            "country": self.country,
            "city": self.city,
            "latitude": self.latitude,
            "longitude": self.longitude,
            "value": self.value,
            "unit": self.unit,
            # Exposed under the external/DB name "metadata".
            "metadata": self.extra_data,
            "collected_at": self.collected_at.isoformat()
            if self.collected_at is not None
            else None,
            "reference_date": self.reference_date.isoformat()
            if self.reference_date is not None
            else None,
        }
|
||||
28
backend/app/models/datasource.py
Normal file
28
backend/app/models/datasource.py
Normal file
@@ -0,0 +1,28 @@
|
||||
"""Data Source model"""
|
||||
|
||||
from sqlalchemy import Boolean, Column, DateTime, Integer, String, Text
|
||||
from sqlalchemy.sql import func
|
||||
|
||||
from app.db.session import Base
|
||||
|
||||
|
||||
class DataSource(Base):
    """ORM model describing a registered data source and its schedule."""

    __tablename__ = "data_sources"

    id = Column(Integer, primary_key=True, autoincrement=True)
    name = Column(String(100), nullable=False)  # human-readable label
    source = Column(String(100), nullable=False)  # collector source key
    module = Column(String(10), nullable=False, index=True)  # L1, L2, L3, L4
    priority = Column(String(10), default="P1")  # P0, P1, P2
    frequency_minutes = Column(Integer, default=60)  # scheduling interval
    collector_class = Column(String(100), nullable=False)  # class name used to resolve the collector
    config = Column(Text, default="{}")  # JSON config
    is_active = Column(Boolean, default=True, index=True)
    # Scheduler bookkeeping: last execution, its outcome, and next due time.
    last_run_at = Column(DateTime(timezone=True))
    last_status = Column(String(20))
    next_run_at = Column(DateTime(timezone=True))
    created_at = Column(DateTime(timezone=True), server_default=func.now())
    updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now())

    def __repr__(self):
        return f"<DataSource {self.id}: {self.name}>"
|
||||
26
backend/app/models/datasource_config.py
Normal file
26
backend/app/models/datasource_config.py
Normal file
@@ -0,0 +1,26 @@
|
||||
"""User-defined Data Source Configuration model"""
|
||||
|
||||
from sqlalchemy import Boolean, Column, DateTime, Integer, String, Text, JSON
|
||||
from sqlalchemy.sql import func
|
||||
|
||||
from app.db.session import Base
|
||||
|
||||
|
||||
class DataSourceConfig(Base):
    """User-defined data-source configuration (endpoint, auth, extra options)."""

    __tablename__ = "datasource_configs"

    id = Column(Integer, primary_key=True, autoincrement=True)
    name = Column(String(100), nullable=False)
    description = Column(Text)
    source_type = Column(String(50), nullable=False)  # http, api, database, etc.
    endpoint = Column(String(500))
    auth_type = Column(String(20), default="none")  # none, bearer, api_key, basic
    # default=dict (a callable) on JSON columns so each row gets its own fresh
    # dict — a literal {} default would be one shared mutable object reused
    # across all rows.
    auth_config = Column(JSON, default=dict)  # Encrypted credentials
    headers = Column(JSON, default=dict)
    config = Column(JSON, default=dict)  # Additional config like timeout, retry, etc.
    is_active = Column(Boolean, default=True)
    created_at = Column(DateTime(timezone=True), server_default=func.now())
    updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now())

    def __repr__(self):
        return f"<DataSourceConfig {self.id}: {self.name}>"
|
||||
29
backend/app/models/gpu_cluster.py
Normal file
29
backend/app/models/gpu_cluster.py
Normal file
@@ -0,0 +1,29 @@
|
||||
"""GPU Cluster model for L1 data"""
|
||||
|
||||
from sqlalchemy import Column, DateTime, Float, Integer, String, Text
|
||||
from sqlalchemy.sql import func
|
||||
|
||||
from app.db.session import Base
|
||||
|
||||
|
||||
class GPUCluster(Base):
    """ORM model for GPU/supercomputer cluster observations (L1 layer data)."""

    __tablename__ = "gpu_clusters"

    id = Column(Integer, primary_key=True, autoincrement=True)
    time = Column(DateTime(timezone=True), nullable=False)  # observation timestamp
    cluster_id = Column(String(100), nullable=False, index=True)  # stable external identifier
    name = Column(String(200), nullable=False)
    country = Column(String(100))
    city = Column(String(100))
    latitude = Column(Float)
    longitude = Column(Float)
    organization = Column(String(200))
    gpu_count = Column(Integer)
    gpu_type = Column(String(100))
    total_flops = Column(Float)  # unit not specified here — TODO confirm (FLOPS? TFlop/s?)
    rank = Column(Integer)  # e.g. ranking position in the source list
    source = Column(String(50), nullable=False)  # which collector produced the row
    created_at = Column(DateTime(timezone=True), server_default=func.now())

    def __repr__(self):
        return f"<GPUCluster {self.cluster_id}: {self.name}>"
|
||||
22
backend/app/models/task.py
Normal file
22
backend/app/models/task.py
Normal file
@@ -0,0 +1,22 @@
|
||||
"""Collection Task model"""
|
||||
|
||||
from sqlalchemy import Column, DateTime, Integer, String, Text
|
||||
from sqlalchemy.sql import func
|
||||
|
||||
from app.db.session import Base
|
||||
|
||||
|
||||
class CollectionTask(Base):
    """ORM model recording one execution of a collector (audit/history row)."""

    __tablename__ = "collection_tasks"

    id = Column(Integer, primary_key=True, autoincrement=True)
    # Loose reference to data_sources.id (no ForeignKey constraint).
    datasource_id = Column(Integer, nullable=False, index=True)
    status = Column(String(20), nullable=False)  # pending, running, success, failed, cancelled
    started_at = Column(DateTime(timezone=True))
    completed_at = Column(DateTime(timezone=True))
    records_processed = Column(Integer, default=0)
    error_message = Column(Text)  # populated only on failure
    created_at = Column(DateTime(timezone=True), server_default=func.now())

    def __repr__(self):
        return f"<CollectionTask {self.id}: {self.status}>"
|
||||
25
backend/app/models/user.py
Normal file
25
backend/app/models/user.py
Normal file
@@ -0,0 +1,25 @@
|
||||
from sqlalchemy import Boolean, Column, Integer, String, DateTime
|
||||
from sqlalchemy.sql import func
|
||||
|
||||
from app.db.session import Base
|
||||
|
||||
|
||||
class User(Base):
    """ORM model for application user accounts."""

    __tablename__ = "users"

    id = Column(Integer, primary_key=True, index=True)
    username = Column(String(50), unique=True, index=True, nullable=False)
    email = Column(String(255), unique=True, index=True, nullable=False)
    password_hash = Column(String(255), nullable=False)  # only the hash is stored, never the password
    role = Column(String(20), default="viewer")  # role set beyond "viewer" not visible here — confirm
    is_active = Column(Boolean, default=True)
    last_login_at = Column(DateTime(timezone=True))
    created_at = Column(DateTime(timezone=True), server_default=func.now())
    updated_at = Column(
        DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
    )

    def set_password(self, password: str):
        """Hash *password* and store it on the instance (does not commit)."""
        # Local import — presumably to avoid a circular import with app.core.security.
        from app.core.security import get_password_hash

        self.password_hash = get_password_hash(password)
|
||||
BIN
backend/app/schemas/__pycache__/token.cpython-311.pyc
Normal file
BIN
backend/app/schemas/__pycache__/token.cpython-311.pyc
Normal file
Binary file not shown.
BIN
backend/app/schemas/__pycache__/user.cpython-311.pyc
Normal file
BIN
backend/app/schemas/__pycache__/user.cpython-311.pyc
Normal file
Binary file not shown.
22
backend/app/schemas/token.py
Normal file
22
backend/app/schemas/token.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class Token(BaseModel):
    """Response body for a successful login: access token plus user info."""

    access_token: str
    token_type: str = "bearer"
    expires_in: int  # lifetime of access_token; unit (seconds?) — TODO confirm against auth code
    user: dict  # serialized user payload; schema defined by the auth endpoint
|
||||
|
||||
|
||||
class TokenPayload(BaseModel):
    """Decoded JWT claims used by the app."""

    sub: int  # subject — presumably the user id; confirm against token issuer
    exp: datetime  # expiry claim
    type: str  # token kind, e.g. access vs refresh — confirm against issuer
|
||||
|
||||
|
||||
class TokenRefresh(BaseModel):
    """Response body for the token-refresh endpoint: a new access token."""

    access_token: str
    expires_in: int  # lifetime of the new token; unit — TODO confirm
|
||||
41
backend/app/schemas/user.py
Normal file
41
backend/app/schemas/user.py
Normal file
@@ -0,0 +1,41 @@
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, EmailStr, Field
|
||||
|
||||
|
||||
class UserBase(BaseModel):
    """Fields shared by all user schemas."""

    username: str
    email: EmailStr
|
||||
|
||||
|
||||
class UserCreate(UserBase):
    """Request body for creating a user."""

    password: str = Field(..., min_length=8)  # plaintext here; hashed before storage
    role: str = "viewer"
|
||||
|
||||
|
||||
class UserUpdate(BaseModel):
    """Partial-update request body; every field is optional."""

    email: Optional[EmailStr] = None
    role: Optional[str] = None
    is_active: Optional[bool] = None
|
||||
|
||||
|
||||
class UserInDB(UserBase):
    """User as stored in the database (without the password hash)."""

    id: int
    role: str
    is_active: bool
    last_login_at: Optional[datetime]
    created_at: datetime

    class Config:
        # Pydantic v2 "ORM mode": allow constructing from SQLAlchemy objects.
        from_attributes = True
|
||||
|
||||
|
||||
class UserResponse(UserBase):
    """Public user representation returned by the API."""

    id: int
    role: str
    is_active: bool
    created_at: datetime

    class Config:
        # Pydantic v2 "ORM mode": allow constructing from SQLAlchemy objects.
        from_attributes = True
|
||||
BIN
backend/app/services/__pycache__/scheduler.cpython-311.pyc
Normal file
BIN
backend/app/services/__pycache__/scheduler.cpython-311.pyc
Normal file
Binary file not shown.
41
backend/app/services/collectors/__init__.py
Normal file
41
backend/app/services/collectors/__init__.py
Normal file
@@ -0,0 +1,41 @@
|
||||
"""__init__.py for collectors package"""
|
||||
|
||||
# Re-export the collector base classes, the registry, and every concrete
# collector so callers can `from app.services.collectors import <Name>`.
from app.services.collectors.base import BaseCollector, HTTPCollector, IntervalCollector
from app.services.collectors.registry import collector_registry, CollectorRegistry
from app.services.collectors.top500 import TOP500Collector
from app.services.collectors.epoch_ai import EpochAIGPUCollector
from app.services.collectors.huggingface import (
    HuggingFaceModelCollector,
    HuggingFaceDatasetCollector,
    HuggingFaceSpacesCollector,
)
from app.services.collectors.peeringdb import (
    PeeringDBIXPCollector,
    PeeringDBNetworkCollector,
    PeeringDBFacilityCollector,
)
from app.services.collectors.telegeography import (
    TeleGeographyCableCollector,
    TeleGeographyLandingPointCollector,
    TeleGeographyCableSystemCollector,
)
from app.services.collectors.cloudflare import (
    CloudflareRadarDeviceCollector,
    CloudflareRadarTrafficCollector,
    CloudflareRadarTopASCollector,
)

# Register one singleton instance of every built-in collector.
# NOTE: this runs as an import side effect the first time the package is
# imported, so the registry is fully populated before any scheduler starts.
collector_registry.register(TOP500Collector())
collector_registry.register(EpochAIGPUCollector())
collector_registry.register(HuggingFaceModelCollector())
collector_registry.register(HuggingFaceDatasetCollector())
collector_registry.register(HuggingFaceSpacesCollector())
collector_registry.register(PeeringDBIXPCollector())
collector_registry.register(PeeringDBNetworkCollector())
collector_registry.register(PeeringDBFacilityCollector())
collector_registry.register(TeleGeographyCableCollector())
collector_registry.register(TeleGeographyLandingPointCollector())
collector_registry.register(TeleGeographyCableSystemCollector())
collector_registry.register(CloudflareRadarDeviceCollector())
collector_registry.register(CloudflareRadarTrafficCollector())
collector_registry.register(CloudflareRadarTopASCollector())
|
||||
Binary file not shown.
BIN
backend/app/services/collectors/__pycache__/base.cpython-311.pyc
Normal file
BIN
backend/app/services/collectors/__pycache__/base.cpython-311.pyc
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
179
backend/app/services/collectors/base.py
Normal file
179
backend/app/services/collectors/base.py
Normal file
@@ -0,0 +1,179 @@
|
||||
"""Base collector class for all data sources"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, List, Any, Optional
|
||||
from datetime import datetime
|
||||
import httpx
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
|
||||
class BaseCollector(ABC):
    """Abstract base class for data collectors.

    Subclasses implement :meth:`fetch` (and optionally override
    :meth:`transform`); :meth:`run` drives the fetch -> transform -> save
    pipeline and records a CollectionTask row describing the outcome.

    Fix vs. original: removed an unused local import of CollectedData inside
    ``run`` (only ``_save_data`` uses it, via its own local import).
    """

    name: str = "base_collector"
    priority: str = "P1"  # P0 (highest) .. P2
    module: str = "L1"  # data layer: L1 .. L4
    frequency_hours: int = 4
    data_type: str = "generic"  # Override in subclass: "supercomputer", "model", "dataset", etc.

    @abstractmethod
    async def fetch(self) -> List[Dict[str, Any]]:
        """Fetch raw data from source."""
        pass

    def transform(self, raw_data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Transform raw data to internal format (default: pass through)."""
        return raw_data

    async def run(self, db: AsyncSession) -> Dict[str, Any]:
        """Full pipeline: fetch -> transform -> save.

        Returns a status dict with "status" of "success", "failed" or
        "skipped". Exceptions from fetch/transform/save are caught and
        reported in the returned dict rather than propagated.
        """
        # Local imports — presumably to avoid circular imports at module load.
        from app.services.collectors.registry import collector_registry
        from app.models.task import CollectionTask

        start_time = datetime.utcnow()
        datasource_id = getattr(self, "_datasource_id", 1)  # Default to 1 for built-in collectors

        # Skip entirely when the collector has been disabled in the registry.
        if not collector_registry.is_active(self.name):
            return {"status": "skipped", "reason": "Collector is disabled"}

        # Record task start so a crash mid-run still leaves a "running" row.
        task = CollectionTask(
            datasource_id=datasource_id,
            status="running",
            started_at=start_time,
        )
        db.add(task)
        await db.commit()
        task_id = task.id

        try:
            raw_data = await self.fetch()
            data = self.transform(raw_data)

            records_count = await self._save_data(db, data)

            # Mark the task row as successful.
            task.status = "success"
            task.records_processed = records_count
            task.completed_at = datetime.utcnow()
            await db.commit()

            return {
                "status": "success",
                "task_id": task_id,
                "records_processed": records_count,
                "execution_time_seconds": (datetime.utcnow() - start_time).total_seconds(),
            }
        except Exception as e:
            # Broad catch is deliberate: one failing collector must not take
            # down the caller's scheduling loop; the error is persisted instead.
            task.status = "failed"
            task.error_message = str(e)
            task.completed_at = datetime.utcnow()
            await db.commit()

            return {
                "status": "failed",
                "task_id": task_id,
                "error": str(e),
                "execution_time_seconds": (datetime.utcnow() - start_time).total_seconds(),
            }

    async def _save_data(self, db: AsyncSession, data: List[Dict[str, Any]]) -> int:
        """Persist normalized record dicts as CollectedData rows; return the count."""
        from app.models.collected_data import CollectedData

        if not data:
            return 0

        # One shared timestamp so the whole batch is queryable as a snapshot.
        collected_at = datetime.utcnow()
        records_added = 0

        for item in data:
            record = CollectedData(
                source=self.name,
                source_id=item.get("source_id") or item.get("id"),
                data_type=self.data_type,
                name=item.get("name"),
                title=item.get("title"),
                description=item.get("description"),
                country=item.get("country"),
                city=item.get("city"),
                # Coordinates are stored as strings; a missing/None value stays None.
                latitude=str(item.get("latitude", ""))
                if item.get("latitude") is not None
                else None,
                longitude=str(item.get("longitude", ""))
                if item.get("longitude") is not None
                else None,
                value=item.get("value"),
                unit=item.get("unit"),
                extra_data=item.get("metadata", {}),
                collected_at=collected_at,
                # Accept ISO 8601 with a trailing "Z" by mapping it to +00:00.
                reference_date=datetime.fromisoformat(
                    item.get("reference_date").replace("Z", "+00:00")
                )
                if item.get("reference_date")
                else None,
                is_valid=1,
            )
            db.add(record)
            records_added += 1

        await db.commit()
        return records_added

    async def save(self, db: AsyncSession, data: List[Dict[str, Any]]) -> int:
        """Save data to database (legacy method, use _save_data instead)."""
        return await self._save_data(db, data)
|
||||
|
||||
|
||||
class HTTPCollector(BaseCollector):
    """Base class for collectors that fetch a single JSON endpoint over HTTP."""

    base_url: str = ""
    # NOTE(review): class-level mutable dict is shared by every subclass that
    # does not replace it; current subclasses assign self.headers in __init__,
    # so no bug manifests today — worth revisiting if a subclass mutates it.
    headers: Dict[str, str] = {}

    async def fetch(self) -> List[Dict[str, Any]]:
        # One GET of base_url; non-2xx raises, JSON body goes to the subclass parser.
        async with httpx.AsyncClient(timeout=60.0) as client:
            response = await client.get(self.base_url, headers=self.headers)
            response.raise_for_status()
            return self.parse_response(response.json())

    @abstractmethod
    def parse_response(self, response: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Turn the decoded JSON payload into normalized record dicts."""
        pass
|
||||
|
||||
|
||||
class IntervalCollector(BaseCollector):
    """Marker base class for collectors that run on a fixed schedule.

    Behaviorally identical to BaseCollector: the previous ``run`` override was
    a pure pass-through (``return await super().run(db)``) and has been
    removed as redundant. The class is kept as a semantic marker and future
    extension point, so ``isinstance`` checks and subclasses keep working.
    """
|
||||
|
||||
|
||||
async def log_task(
    db: AsyncSession,
    datasource_id: int,
    status: str,
    records_processed: int = 0,
    error_message: Optional[str] = None,
):
    """Write a single, already-finished CollectionTask row and commit.

    started_at and completed_at are both set to "now" (two separate utcnow()
    calls), so task duration is not captured by this helper.
    """
    # Local import — presumably to avoid a circular import with app.models.
    from app.models.task import CollectionTask

    task = CollectionTask(
        datasource_id=datasource_id,
        status=status,
        records_processed=records_processed,
        error_message=error_message,
        started_at=datetime.utcnow(),
        completed_at=datetime.utcnow(),
    )
    db.add(task)
    await db.commit()
|
||||
163
backend/app/services/collectors/cloudflare.py
Normal file
163
backend/app/services/collectors/cloudflare.py
Normal file
@@ -0,0 +1,163 @@
|
||||
"""Cloudflare Radar Traffic Collector
|
||||
|
||||
Collects Internet traffic data from Cloudflare Radar API.
|
||||
https://developers.cloudflare.com/radar/
|
||||
|
||||
Note: Radar API provides free access to global Internet traffic data.
|
||||
Some endpoints require authentication for higher rate limits.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from typing import Dict, Any, List
|
||||
from datetime import datetime
|
||||
|
||||
import httpx
|
||||
from app.services.collectors.base import HTTPCollector
|
||||
|
||||
# Cloudflare API token (optional - for higher rate limits)
|
||||
CLOUDFLARE_API_TOKEN = os.environ.get("CLOUDFLARE_API_TOKEN", "")
|
||||
|
||||
|
||||
class CloudflareRadarDeviceCollector(HTTPCollector):
    """Collects device type distribution data (mobile vs desktop)."""

    name = "cloudflare_radar_device"
    priority = "P2"
    module = "L3"
    frequency_hours = 24
    data_type = "device_stats"
    base_url = "https://api.cloudflare.com/client/v4/radar/http/summary/device_type"

    def __init__(self):
        super().__init__()
        # Instance-level headers (shadow the shared class attribute on HTTPCollector).
        self.headers = {
            "User-Agent": "Planet-Intelligence-System/1.0 (Python/collector)",
            "Accept": "application/json",
        }
        # Token is optional; it raises Radar API rate limits when present.
        if CLOUDFLARE_API_TOKEN:
            self.headers["Authorization"] = f"Bearer {CLOUDFLARE_API_TOKEN}"

    def parse_response(self, response: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Parse the Radar device-type summary into a single global record."""
        data = []
        result = response.get("result", {})
        # "summary_0" is the Radar response key for the first summary series.
        summary = result.get("summary_0", {})

        try:
            entry = {
                "source_id": "cloudflare_radar_device_global",
                "name": "Global Device Distribution",
                # Global aggregate, so a placeholder country and 0,0 coordinates.
                "country": "GLOBAL",
                "city": "",
                "latitude": 0.0,
                "longitude": 0.0,
                "metadata": {
                    "desktop_percent": float(summary.get("desktop", 0)),
                    "mobile_percent": float(summary.get("mobile", 0)),
                    "other_percent": float(summary.get("other", 0)),
                    "date_range": result.get("meta", {}).get("dateRange", {}),
                },
                "reference_date": datetime.utcnow().isoformat(),
            }
            data.append(entry)
        except (ValueError, TypeError, KeyError):
            # Malformed payload: deliberately best-effort, return whatever parsed.
            pass

        return data
|
||||
|
||||
|
||||
class CloudflareRadarTrafficCollector(HTTPCollector):
    """Collects traffic volume trends (Radar HTTP requests timeseries)."""

    name = "cloudflare_radar_traffic"
    priority = "P2"
    module = "L3"
    frequency_hours = 24
    data_type = "traffic_stats"
    base_url = "https://api.cloudflare.com/client/v4/radar/http/timeseries/requests"

    def __init__(self):
        super().__init__()
        # Instance-level headers (shadow the shared class attribute on HTTPCollector).
        self.headers = {
            "User-Agent": "Planet-Intelligence-System/1.0 (Python/collector)",
            "Accept": "application/json",
        }
        # Token is optional; it raises Radar API rate limits when present.
        if CLOUDFLARE_API_TOKEN:
            self.headers["Authorization"] = f"Bearer {CLOUDFLARE_API_TOKEN}"

    def parse_response(self, response: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Parse the Radar requests timeseries: one record per data point."""
        data = []
        result = response.get("result", {})
        # "requests_0" is the Radar response key for the first requests series.
        timeseries = result.get("requests_0", {}).get("timeseries", [])

        for item in timeseries:
            try:
                entry = {
                    "source_id": f"cloudflare_traffic_{item.get('datetime', '')}",
                    # [:10] keeps only the date part of the ISO timestamp.
                    "name": f"Traffic {item.get('datetime', '')[:10]}",
                    # Global aggregate, so placeholder country and 0,0 coordinates.
                    "country": "GLOBAL",
                    "city": "",
                    "latitude": 0.0,
                    "longitude": 0.0,
                    "metadata": {
                        "datetime": item.get("datetime"),
                        "requests": item.get("requests"),
                        "visit_duration": item.get("visitDuration"),
                    },
                    "reference_date": item.get("datetime", datetime.utcnow().isoformat()),
                }
                data.append(entry)
            except (ValueError, TypeError, KeyError):
                # Skip malformed points; keep the rest of the series.
                continue

        return data
|
||||
|
||||
|
||||
class CloudflareRadarTopASCollector(HTTPCollector):
    """Collects top traffic origins from Cloudflare Radar.

    NOTE(review): the class/name say "autonomous systems", but the endpoint is
    /radar/http/top/locations and the parser reads country/city fields — this
    actually collects top *locations*. Flagged rather than renamed because the
    registry and data rows reference the existing name.
    """

    name = "cloudflare_radar_top_as"
    priority = "P2"
    module = "L2"
    frequency_hours = 24
    data_type = "as_stats"
    base_url = "https://api.cloudflare.com/client/v4/radar/http/top/locations"

    def __init__(self):
        super().__init__()
        # Instance-level headers (shadow the shared class attribute on HTTPCollector).
        self.headers = {
            "User-Agent": "Planet-Intelligence-System/1.0 (Python/collector)",
            "Accept": "application/json",
        }
        # Token is optional; it raises Radar API rate limits when present.
        if CLOUDFLARE_API_TOKEN:
            self.headers["Authorization"] = f"Bearer {CLOUDFLARE_API_TOKEN}"

    def parse_response(self, response: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Parse the Radar top-locations list: one record per ranked location."""
        data = []
        result = response.get("result", {})
        top_locations = result.get("top_locations_0", [])

        for idx, item in enumerate(top_locations):
            try:
                entry = {
                    # Fall back to the enumeration index when "rank" is absent.
                    "source_id": f"cloudflare_as_{item.get('rank', idx)}",
                    "name": item.get("location", {}).get("countryName", "Unknown"),
                    "country": item.get("location", {}).get("countryCode", "XX"),
                    "city": item.get("location", {}).get("cityName", ""),
                    "latitude": float(item.get("location", {}).get("latitude", 0)),
                    "longitude": float(item.get("location", {}).get("longitude", 0)),
                    "metadata": {
                        "rank": item.get("rank"),
                        "traffic_share": item.get("trafficShare"),
                        "country_code": item.get("location", {}).get("countryCode"),
                    },
                    "reference_date": datetime.utcnow().isoformat(),
                }
                data.append(entry)
            except (ValueError, TypeError, KeyError):
                # Skip malformed entries; keep the rest of the list.
                continue

        return data
|
||||
118
backend/app/services/collectors/epoch_ai.py
Normal file
118
backend/app/services/collectors/epoch_ai.py
Normal file
@@ -0,0 +1,118 @@
|
||||
"""Epoch AI GPU Clusters Collector
|
||||
|
||||
Collects data from Epoch AI GPU clusters tracking.
|
||||
https://epoch.ai/data/gpu-clusters
|
||||
"""
|
||||
|
||||
import re
|
||||
from typing import Dict, Any, List
|
||||
from datetime import datetime
|
||||
from bs4 import BeautifulSoup
|
||||
import httpx
|
||||
|
||||
from app.services.collectors.base import BaseCollector
|
||||
|
||||
|
||||
class EpochAIGPUCollector(BaseCollector):
    """Scrapes the Epoch AI GPU-clusters web page (HTML tables) into records."""

    name = "epoch_ai_gpu"
    priority = "P0"
    module = "L1"
    frequency_hours = 6
    data_type = "gpu_cluster"

    async def fetch(self) -> List[Dict[str, Any]]:
        """Fetch Epoch AI GPU clusters data from webpage"""
        url = "https://epoch.ai/data/gpu-clusters"

        async with httpx.AsyncClient(timeout=60.0) as client:
            response = await client.get(url)
            response.raise_for_status()
            # Scraped as HTML text, not JSON.
            return self.parse_response(response.text)

    def parse_response(self, html: str) -> List[Dict[str, Any]]:
        """Parse Epoch AI webpage to extract GPU cluster data.

        Assumes column layout name/location/performance in the first three
        cells — TODO confirm; the page structure may change, hence the sample
        fallback below.
        """
        data = []
        soup = BeautifulSoup(html, "html.parser")

        # Try to find data table on the page
        tables = soup.find_all("table")
        for table in tables:
            rows = table.find_all("tr")
            for row in rows[1:]:  # Skip header
                cells = row.find_all(["td", "th"])
                if len(cells) >= 5:
                    try:
                        cluster_name = cells[0].get_text(strip=True)
                        # Skip repeated header-like rows and empty names.
                        if not cluster_name or cluster_name in ["Cluster", "System", "Name"]:
                            continue

                        location_cell = cells[1].get_text(strip=True) if len(cells) > 1 else ""
                        country, city = self._parse_location(location_cell)

                        perf_cell = cells[2].get_text(strip=True) if len(cells) > 2 else ""

                        entry = {
                            # Slugify the name into a stable ascii source id.
                            "source_id": f"epoch_{re.sub(r'[^a-zA-Z0-9]', '_', cluster_name.lower())}",
                            "name": cluster_name,
                            "country": country,
                            "city": city,
                            "latitude": "",
                            "longitude": "",
                            "value": self._parse_performance(perf_cell),
                            "unit": "TFlop/s",
                            "metadata": {
                                "raw_data": perf_cell,
                            },
                            "reference_date": datetime.utcnow().strftime("%Y-%m-%d"),
                        }
                        data.append(entry)
                    except (ValueError, IndexError, AttributeError):
                        # Best-effort scraping: skip rows that don't parse.
                        continue

        # If no table found, return sample data
        if not data:
            data = self._get_sample_data()

        return data

    def _parse_location(self, location: str) -> tuple:
        """Parse a "City, Country" string into (country, city)."""
        if not location:
            return "", ""
        if "," in location:
            # rsplit: the last comma-separated token is the country, the rest the city.
            parts = location.rsplit(",", 1)
            city = parts[0].strip()
            country = parts[1].strip() if len(parts) > 1 else ""
            return country, city
        # No comma: treat the whole string as the country.
        return location, ""

    def _parse_performance(self, perf: str) -> str:
        """Extract the leading numeric value (commas stripped) from a perf string."""
        if not perf:
            return "0"
        # Prefer a number followed by a FLOP/s unit, then any number at all.
        match = re.search(r"([\d,.]+)\s*(TFlop/s|PFlop/s|GFlop/s)?", perf, re.I)
        if match:
            return match.group(1).replace(",", "")
        match = re.search(r"([\d,.]+)", perf)
        if match:
            return match.group(1).replace(",", "")
        return "0"

    def _get_sample_data(self) -> List[Dict[str, Any]]:
        """Return sample data for testing when scraping fails"""
        return [
            {
                "source_id": "epoch_sample_1",
                "name": "Sample GPU Cluster",
                "country": "United States",
                "city": "San Francisco, CA",
                "latitude": "",
                "longitude": "",
                "value": "1000",
                "unit": "TFlop/s",
                "metadata": {
                    "note": "Sample data - Epoch AI page structure may vary",
                },
                "reference_date": datetime.utcnow().strftime("%Y-%m-%d"),
            },
        ]
|
||||
136
backend/app/services/collectors/huggingface.py
Normal file
136
backend/app/services/collectors/huggingface.py
Normal file
@@ -0,0 +1,136 @@
|
||||
"""Hugging Face Model Ecosystem Collector
|
||||
|
||||
Collects data from Hugging Face model hub.
|
||||
https://huggingface.co/models
|
||||
https://huggingface.co/datasets
|
||||
https://huggingface.co/spaces
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, List
|
||||
from datetime import datetime
|
||||
|
||||
from app.services.collectors.base import HTTPCollector
|
||||
|
||||
|
||||
class HuggingFaceModelCollector(HTTPCollector):
    """Collects model listings from the Hugging Face hub API."""

    name = "huggingface_models"
    priority = "P1"
    module = "L2"
    frequency_hours = 12
    data_type = "model"
    base_url = "https://huggingface.co/api/models"

    def parse_response(self, response: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Normalize the models API payload into collector entries.

        Accepts either a bare JSON list or an object keyed by ``models``
        (or ``items``); at most the first 100 records are kept.
        """
        if isinstance(response, list):
            models = response
        else:
            models = response.get("models", response.get("items", []))

        entries: List[Dict[str, Any]] = []
        for raw in models[:100]:
            try:
                entries.append(
                    {
                        "source_id": f"hf_model_{raw.get('id', '')}",
                        "name": raw.get("id", "Unknown"),
                        "description": (raw.get("description", "") or "")[:500],
                        "metadata": {
                            "author": raw.get("author"),
                            "likes": raw.get("likes"),
                            "downloads": raw.get("downloads"),
                            "language": raw.get("language"),
                            "tags": (raw.get("tags", []) or [])[:10],
                            "pipeline_tag": raw.get("pipeline_tag"),
                            "library_name": raw.get("library_name"),
                            "created_at": raw.get("createdAt"),
                        },
                        "reference_date": datetime.utcnow().strftime("%Y-%m-%d"),
                    }
                )
            except (ValueError, TypeError, KeyError):
                # Skip malformed records rather than failing the whole batch.
                continue

        return entries
|
||||
|
||||
|
||||
class HuggingFaceDatasetCollector(HTTPCollector):
    """Collects dataset listings from the Hugging Face hub API."""

    name = "huggingface_datasets"
    priority = "P1"
    module = "L2"
    frequency_hours = 12
    data_type = "dataset"
    base_url = "https://huggingface.co/api/datasets"

    def parse_response(self, response: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Normalize the datasets API payload into collector entries.

        Accepts either a bare JSON list or an object keyed by ``datasets``
        (or ``items``); at most the first 100 records are kept.
        """
        if isinstance(response, list):
            datasets = response
        else:
            datasets = response.get("datasets", response.get("items", []))

        entries: List[Dict[str, Any]] = []
        for raw in datasets[:100]:
            try:
                entries.append(
                    {
                        "source_id": f"hf_dataset_{raw.get('id', '')}",
                        "name": raw.get("id", "Unknown"),
                        "description": (raw.get("description", "") or "")[:500],
                        "metadata": {
                            "author": raw.get("author"),
                            "likes": raw.get("likes"),
                            "downloads": raw.get("downloads"),
                            "size": raw.get("size"),
                            "language": raw.get("language"),
                            "tags": (raw.get("tags", []) or [])[:10],
                            "created_at": raw.get("createdAt"),
                        },
                        "reference_date": datetime.utcnow().strftime("%Y-%m-%d"),
                    }
                )
            except (ValueError, TypeError, KeyError):
                # Skip malformed records rather than failing the whole batch.
                continue

        return entries
|
||||
|
||||
|
||||
class HuggingFaceSpacesCollector(HTTPCollector):
    """Collects Spaces listings from the Hugging Face hub API."""

    name = "huggingface_spaces"
    priority = "P2"
    module = "L2"
    frequency_hours = 24
    data_type = "space"
    base_url = "https://huggingface.co/api/spaces"

    def parse_response(self, response: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Normalize the Spaces API payload into collector entries.

        Accepts either a bare JSON list or an object keyed by ``spaces``
        (or ``items``); at most the first 100 records are kept.
        """
        if isinstance(response, list):
            spaces = response
        else:
            spaces = response.get("spaces", response.get("items", []))

        entries: List[Dict[str, Any]] = []
        for raw in spaces[:100]:
            try:
                entries.append(
                    {
                        "source_id": f"hf_space_{raw.get('id', '')}",
                        "name": raw.get("id", "Unknown"),
                        "description": (raw.get("description", "") or "")[:500],
                        "metadata": {
                            "author": raw.get("author"),
                            "likes": raw.get("likes"),
                            "views": raw.get("views"),
                            "sdk": raw.get("sdk"),
                            "hardware": raw.get("hardware"),
                            "tags": (raw.get("tags", []) or [])[:10],
                            "created_at": raw.get("createdAt"),
                        },
                        "reference_date": datetime.utcnow().strftime("%Y-%m-%d"),
                    }
                )
            except (ValueError, TypeError, KeyError):
                # Skip malformed records rather than failing the whole batch.
                continue

        return entries
|
||||
331
backend/app/services/collectors/peeringdb.py
Normal file
331
backend/app/services/collectors/peeringdb.py
Normal file
@@ -0,0 +1,331 @@
|
||||
"""PeeringDB IXP Nodes Collector
|
||||
|
||||
Collects data from PeeringDB IXP directory.
|
||||
https://www.peeringdb.com
|
||||
|
||||
Note: PeeringDB API has rate limits:
|
||||
- Anonymous: 20 requests/minute
|
||||
- Authenticated: 40 requests/minute (with API key)
|
||||
|
||||
To get higher limits, set PEERINGDB_API_KEY environment variable.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from typing import Dict, Any, List
|
||||
from datetime import datetime
|
||||
|
||||
import httpx
|
||||
from app.services.collectors.base import HTTPCollector
|
||||
|
||||
# PeeringDB API key - read from environment variable
|
||||
PEERINGDB_API_KEY = os.environ.get("PEERINGDB_API_KEY", "")
|
||||
|
||||
|
||||
class PeeringDBIXPCollector(HTTPCollector):
    """Collects Internet Exchange Point (IXP) records from PeeringDB.

    PeeringDB rate-limits anonymous clients (20 req/min), so fetching goes
    through :meth:`fetch_with_retry`, which backs off exponentially on
    HTTP 429 and on transient transport errors.
    """

    name = "peeringdb_ixp"
    priority = "P1"
    module = "L2"
    frequency_hours = 24
    data_type = "ixp"
    base_url = "https://www.peeringdb.com/api/ix"

    def __init__(self):
        super().__init__()
        # Descriptive User-Agent so PeeringDB operators can identify traffic.
        self.headers = {
            "User-Agent": "Planet-Intelligence-System/1.0 (Python/collector)",
            "Accept": "application/json",
        }
        # NOTE(review): the key is sent as a URL query parameter; PeeringDB's
        # documented mechanism is the "Authorization: Api-Key <key>" header.
        # Kept as-is to avoid changing request behavior, but the header form
        # should be preferred (query strings end up in server logs).
        if PEERINGDB_API_KEY:
            self.base_url = f"{self.base_url}?key={PEERINGDB_API_KEY}"

    async def fetch_with_retry(
        self, max_retries: int = 3, base_delay: float = 2.0
    ) -> Dict[str, Any]:
        """Fetch the endpoint, retrying with exponential backoff.

        Retries on HTTP 429 (rate limit) and on transient transport
        failures (timeouts, connection resets).  Other HTTP status errors
        are raised to the caller.  Returns ``{}`` when all retries fail.

        Args:
            max_retries: Number of attempts before giving up.
            base_delay: Initial backoff in seconds; doubles each attempt.
        """
        last_error = None

        for attempt in range(max_retries):
            try:
                async with httpx.AsyncClient(timeout=60.0) as client:
                    response = await client.get(self.base_url, headers=self.headers)

                    if response.status_code == 429:
                        # Rate limited - wait and retry with exponential backoff
                        delay = base_delay * (2**attempt)
                        print(f"PeeringDB rate limited, waiting {delay}s before retry...")
                        await asyncio.sleep(delay)
                        last_error = "Rate limited"
                        continue

                    response.raise_for_status()
                    return response.json()

            except httpx.RequestError as e:
                # Fix: transient network failures (timeouts, resets) used to
                # escape this loop and abort the whole collection run; retry
                # them with the same backoff as a rate limit.
                delay = base_delay * (2**attempt)
                print(f"PeeringDB request failed ({e}), waiting {delay}s before retry...")
                await asyncio.sleep(delay)
                last_error = str(e)
                continue
            except httpx.HTTPStatusError:
                # 429 is intercepted before raise_for_status(), so any status
                # error reaching this point is a genuine failure: propagate.
                raise

        print(f"Warning: PeeringDB collection failed after {max_retries} retries: {last_error}")
        return {}

    async def collect(self) -> List[Dict[str, Any]]:
        """Collect IXP data, returning an empty list when fetching fails."""
        response_data = await self.fetch_with_retry()
        if not response_data:
            return []
        return self.parse_response(response_data)

    def parse_response(self, response: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Normalize the PeeringDB ``/api/ix`` payload into entries."""
        data = []
        ixps = response.get("data", response.get("ixps", []))

        for item in ixps:
            try:
                entry = {
                    "source_id": f"peeringdb_ixp_{item.get('id', '')}",
                    "name": item.get("name", "Unknown"),
                    "country": item.get("country", "Unknown"),
                    "city": item.get("city", ""),
                    "latitude": self._parse_coordinate(item.get("latitude")),
                    "longitude": self._parse_coordinate(item.get("longitude")),
                    "metadata": {
                        "org_name": item.get("org_name"),
                        "url": item.get("url"),
                        "tech_email": item.get("tech_email"),
                        "tech_phone": item.get("tech_phone"),
                        "network_count": len(item.get("net_set", [])),
                        "created": item.get("created"),
                        "updated": item.get("updated"),
                    },
                    "reference_date": datetime.utcnow().isoformat(),
                }
                data.append(entry)
            except (ValueError, TypeError, KeyError):
                # Skip malformed records; partial data beats none.
                continue

        return data

    def _parse_coordinate(self, value: Any) -> float:
        """Coerce a latitude/longitude value to float, defaulting to 0.0.

        PeeringDB may return coordinates as numbers, numeric strings, or
        null; anything unparseable maps to 0.0.
        """
        if value is None:
            return 0.0
        if isinstance(value, (int, float)):
            return float(value)
        if isinstance(value, str):
            try:
                return float(value)
            except ValueError:
                return 0.0
        return 0.0
|
||||
|
||||
|
||||
class PeeringDBNetworkCollector(HTTPCollector):
    """Collects network (ASN) records from PeeringDB.

    Uses the same rate-limit-aware fetch strategy as the IXP collector:
    exponential backoff on HTTP 429 and on transient transport errors.
    """

    name = "peeringdb_network"
    priority = "P2"
    module = "L2"
    frequency_hours = 48
    data_type = "network"
    base_url = "https://www.peeringdb.com/api/net"

    def __init__(self):
        super().__init__()
        # Descriptive User-Agent so PeeringDB operators can identify traffic.
        self.headers = {
            "User-Agent": "Planet-Intelligence-System/1.0 (Python/collector)",
            "Accept": "application/json",
        }
        # NOTE(review): PeeringDB's documented auth is the
        # "Authorization: Api-Key <key>" header; the query parameter is kept
        # for behavioral compatibility but leaks the key into logs.
        if PEERINGDB_API_KEY:
            self.base_url = f"{self.base_url}?key={PEERINGDB_API_KEY}"

    async def fetch_with_retry(
        self, max_retries: int = 3, base_delay: float = 2.0
    ) -> Dict[str, Any]:
        """Fetch the endpoint, retrying with exponential backoff.

        Retries on HTTP 429 and on transient transport failures; other
        HTTP status errors are raised.  Returns ``{}`` when all retries
        are exhausted.
        """
        last_error = None

        for attempt in range(max_retries):
            try:
                async with httpx.AsyncClient(timeout=60.0) as client:
                    response = await client.get(self.base_url, headers=self.headers)

                    if response.status_code == 429:
                        delay = base_delay * (2**attempt)
                        print(f"PeeringDB rate limited, waiting {delay}s before retry...")
                        await asyncio.sleep(delay)
                        last_error = "Rate limited"
                        continue

                    response.raise_for_status()
                    return response.json()

            except httpx.RequestError as e:
                # Fix: timeouts/connection errors previously escaped the
                # retry loop and aborted collection; retry them instead.
                delay = base_delay * (2**attempt)
                print(f"PeeringDB request failed ({e}), waiting {delay}s before retry...")
                await asyncio.sleep(delay)
                last_error = str(e)
                continue
            except httpx.HTTPStatusError:
                # 429 never reaches raise_for_status(); other statuses are
                # genuine failures and propagate to the caller.
                raise

        print(f"Warning: PeeringDB collection failed after {max_retries} retries: {last_error}")
        return {}

    async def collect(self) -> List[Dict[str, Any]]:
        """Collect network data, returning an empty list on fetch failure."""
        response_data = await self.fetch_with_retry()
        if not response_data:
            return []
        return self.parse_response(response_data)

    def parse_response(self, response: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Normalize the PeeringDB ``/api/net`` payload into entries."""
        data = []
        networks = response.get("data", response.get("networks", []))

        for item in networks:
            try:
                entry = {
                    "source_id": f"peeringdb_net_{item.get('id', '')}",
                    "name": item.get("name", "Unknown"),
                    "country": item.get("country", "Unknown"),
                    "city": item.get("city", ""),
                    "latitude": self._parse_coordinate(item.get("latitude")),
                    "longitude": self._parse_coordinate(item.get("longitude")),
                    "metadata": {
                        "asn": item.get("asn"),
                        "irr_as_set": item.get("irr_as_set"),
                        "url": item.get("url"),
                        "info_type": item.get("info_type"),
                        "info_traffic": item.get("info_traffic"),
                        "info_ratio": item.get("info_ratio"),
                        "ix_count": len(item.get("ix_set", [])),
                        "created": item.get("created"),
                        "updated": item.get("updated"),
                    },
                    "reference_date": datetime.utcnow().isoformat(),
                }
                data.append(entry)
            except (ValueError, TypeError, KeyError):
                # Skip malformed records; partial data beats none.
                continue

        return data

    def _parse_coordinate(self, value: Any) -> float:
        """Coerce a latitude/longitude value to float, defaulting to 0.0."""
        if value is None:
            return 0.0
        if isinstance(value, (int, float)):
            return float(value)
        if isinstance(value, str):
            try:
                return float(value)
            except ValueError:
                return 0.0
        return 0.0
|
||||
|
||||
|
||||
class PeeringDBFacilityCollector(HTTPCollector):
    """Collects colocation facility records from PeeringDB.

    Uses the same rate-limit-aware fetch strategy as the other PeeringDB
    collectors: exponential backoff on HTTP 429 and transport errors.
    """

    name = "peeringdb_facility"
    priority = "P2"
    module = "L2"
    frequency_hours = 48
    data_type = "facility"
    base_url = "https://www.peeringdb.com/api/fac"

    def __init__(self):
        super().__init__()
        # Descriptive User-Agent so PeeringDB operators can identify traffic.
        self.headers = {
            "User-Agent": "Planet-Intelligence-System/1.0 (Python/collector)",
            "Accept": "application/json",
        }
        # NOTE(review): PeeringDB's documented auth is the
        # "Authorization: Api-Key <key>" header; the query parameter is kept
        # for behavioral compatibility but leaks the key into logs.
        if PEERINGDB_API_KEY:
            self.base_url = f"{self.base_url}?key={PEERINGDB_API_KEY}"

    async def fetch_with_retry(
        self, max_retries: int = 3, base_delay: float = 2.0
    ) -> Dict[str, Any]:
        """Fetch the endpoint, retrying with exponential backoff.

        Retries on HTTP 429 and on transient transport failures; other
        HTTP status errors are raised.  Returns ``{}`` when all retries
        are exhausted.
        """
        last_error = None

        for attempt in range(max_retries):
            try:
                async with httpx.AsyncClient(timeout=60.0) as client:
                    response = await client.get(self.base_url, headers=self.headers)

                    if response.status_code == 429:
                        delay = base_delay * (2**attempt)
                        print(f"PeeringDB rate limited, waiting {delay}s before retry...")
                        await asyncio.sleep(delay)
                        last_error = "Rate limited"
                        continue

                    response.raise_for_status()
                    return response.json()

            except httpx.RequestError as e:
                # Fix: timeouts/connection errors previously escaped the
                # retry loop and aborted collection; retry them instead.
                delay = base_delay * (2**attempt)
                print(f"PeeringDB request failed ({e}), waiting {delay}s before retry...")
                await asyncio.sleep(delay)
                last_error = str(e)
                continue
            except httpx.HTTPStatusError:
                # 429 never reaches raise_for_status(); other statuses are
                # genuine failures and propagate to the caller.
                raise

        print(f"Warning: PeeringDB collection failed after {max_retries} retries: {last_error}")
        return {}

    async def collect(self) -> List[Dict[str, Any]]:
        """Collect facility data, returning an empty list on fetch failure."""
        response_data = await self.fetch_with_retry()
        if not response_data:
            return []
        return self.parse_response(response_data)

    def parse_response(self, response: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Normalize the PeeringDB ``/api/fac`` payload into entries."""
        data = []
        facilities = response.get("data", response.get("facilities", []))

        for item in facilities:
            try:
                entry = {
                    "source_id": f"peeringdb_fac_{item.get('id', '')}",
                    "name": item.get("name", "Unknown"),
                    "country": item.get("country", "Unknown"),
                    "city": item.get("city", ""),
                    "latitude": self._parse_coordinate(item.get("latitude")),
                    "longitude": self._parse_coordinate(item.get("longitude")),
                    "metadata": {
                        "org_name": item.get("org_name"),
                        "address": item.get("address"),
                        "url": item.get("url"),
                        "rack_count": item.get("rack_count"),
                        "power": item.get("power"),
                        "network_count": len(item.get("net_set", [])),
                        "created": item.get("created"),
                        "updated": item.get("updated"),
                    },
                    "reference_date": datetime.utcnow().isoformat(),
                }
                data.append(entry)
            except (ValueError, TypeError, KeyError):
                # Skip malformed records; partial data beats none.
                continue

        return data

    def _parse_coordinate(self, value: Any) -> float:
        """Coerce a latitude/longitude value to float, defaulting to 0.0."""
        if value is None:
            return 0.0
        if isinstance(value, (int, float)):
            return float(value)
        if isinstance(value, str):
            try:
                return float(value)
            except ValueError:
                return 0.0
        return 0.0
|
||||
43
backend/app/services/collectors/registry.py
Normal file
43
backend/app/services/collectors/registry.py
Normal file
@@ -0,0 +1,43 @@
|
||||
"""Collector registry for managing all data collectors"""
|
||||
|
||||
from typing import Dict, Optional
|
||||
from app.services.collectors.base import BaseCollector
|
||||
|
||||
|
||||
class CollectorRegistry:
    """Process-wide registry of data collectors.

    Collectors are stored by their ``name`` attribute; a separate set
    tracks which ones are currently active.  All state is class-level, so
    every instance (and the module-level singleton below) shares it.
    """

    _collectors: Dict[str, BaseCollector] = {}
    _active_collectors: set = set()

    @classmethod
    def register(cls, collector: BaseCollector):
        """Add *collector* to the registry and mark it active."""
        cls._collectors[collector.name] = collector
        cls._active_collectors.add(collector.name)

    @classmethod
    def get(cls, name: str) -> Optional[BaseCollector]:
        """Return the collector registered under *name*, or ``None``."""
        return cls._collectors.get(name)

    @classmethod
    def all(cls) -> Dict[str, BaseCollector]:
        """Return a shallow copy of the name -> collector mapping."""
        return dict(cls._collectors)

    @classmethod
    def is_active(cls, name: str) -> bool:
        """Return True when the named collector is currently enabled."""
        return name in cls._active_collectors

    @classmethod
    def set_active(cls, name: str, active: bool = True):
        """Enable or disable the named collector."""
        if active:
            cls._active_collectors.add(name)
        else:
            cls._active_collectors.discard(name)


# Shared module-level registry instance used by the rest of the app.
collector_registry = CollectorRegistry()
|
||||
286
backend/app/services/collectors/telegeography.py
Normal file
286
backend/app/services/collectors/telegeography.py
Normal file
@@ -0,0 +1,286 @@
|
||||
"""TeleGeography Submarine Cables Collector
|
||||
|
||||
Collects data from TeleGeography submarine cable database.
|
||||
Uses Wayback Machine as backup data source since live data requires JavaScript rendering.
|
||||
"""
|
||||
|
||||
import json
|
||||
import re
|
||||
from typing import Dict, Any, List
|
||||
from datetime import datetime
|
||||
from bs4 import BeautifulSoup
|
||||
import httpx
|
||||
|
||||
from app.services.collectors.base import BaseCollector
|
||||
|
||||
|
||||
class TeleGeographyCableCollector(BaseCollector):
    """Collects submarine cable records from TeleGeography.

    The live site renders its data with JavaScript, so this collector
    tries an archived API snapshot first, then falls back to scraping the
    page, and finally to built-in sample data.
    """

    name = "telegeography_cables"
    priority = "P1"
    module = "L2"
    frequency_hours = 168  # 7 days
    data_type = "submarine_cable"

    async def fetch(self) -> List[Dict[str, Any]]:
        """Fetch submarine cable data, trying several sources in order."""
        sources = [
            # Wayback Machine archive of TeleGeography
            "https://web.archive.org/web/2024/https://www.submarinecablemap.com/api/v3/cable",
            # Alternative: Try scraping the page
            "https://www.submarinecablemap.com",
        ]

        for url in sources:
            try:
                async with httpx.AsyncClient(timeout=60.0, follow_redirects=True) as client:
                    response = await client.get(url)
                    response.raise_for_status()

                    # A JSON content type (or .json URL) means the archived
                    # API answered; otherwise scrape the HTML page.
                    content_type = response.headers.get("content-type", "")
                    if "application/json" in content_type or url.endswith(".json"):
                        return self.parse_response(response.json())
                    data = self.scrape_cables_from_html(response.text)
                    if data:
                        return data
            except Exception:
                # Best-effort: any failure moves on to the next source.
                continue

        # Every source failed; fall back to static sample data.
        return self._get_sample_data()

    def scrape_cables_from_html(self, html: str) -> List[Dict[str, Any]]:
        """Try to extract a JSON cable list embedded in the page's scripts."""
        data = []
        soup = BeautifulSoup(html, "html.parser")

        # Look for embedded JSON data in <script> tags.
        scripts = soup.find_all("script")
        for script in scripts:
            text = script.string or ""
            if "cable" in text.lower() and ("{" in text or "[" in text):
                # Greedy match grabs the widest [...] span in the script.
                match = re.search(r"\[.+\]", text, re.DOTALL)
                if match:
                    try:
                        potential_data = json.loads(match.group())
                        if isinstance(potential_data, list):
                            return potential_data
                    except ValueError:
                        # Fix: was a bare ``except:`` that also swallowed
                        # SystemExit/KeyboardInterrupt.  json.JSONDecodeError
                        # subclasses ValueError, so bad JSON is still handled.
                        pass

        return data

    def parse_response(self, data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Parse submarine cable records into normalized entries.

        Falls back to sample data when nothing could be parsed.
        """
        result = []

        if not isinstance(data, list):
            data = [data]

        for item in data:
            try:
                entry = {
                    "source_id": f"telegeo_cable_{item.get('id', item.get('cable_id', ''))}",
                    "name": item.get("name", item.get("cable_name", "Unknown")),
                    "country": "",
                    "city": "",
                    "latitude": "",
                    "longitude": "",
                    "value": str(item.get("length", item.get("length_km", 0))),
                    "unit": "km",
                    "metadata": {
                        "owner": item.get("owner"),
                        "operator": item.get("operator"),
                        "length_km": item.get("length", item.get("length_km")),
                        "rfs": item.get("rfs"),
                        "status": item.get("status", "active"),
                        "cable_type": item.get("type", "fiber optic"),
                        "capacity_tbps": item.get("capacity"),
                        "url": item.get("url"),
                    },
                    "reference_date": datetime.utcnow().strftime("%Y-%m-%d"),
                }
                result.append(entry)
            except (ValueError, TypeError, KeyError):
                # Skip malformed records; partial data beats none.
                continue

        if not result:
            result = self._get_sample_data()

        return result

    def _get_sample_data(self) -> List[Dict[str, Any]]:
        """Return static sample submarine cable records."""
        return [
            {
                "source_id": "telegeo_sample_1",
                "name": "2Africa",
                "country": "",
                "city": "",
                "latitude": "",
                "longitude": "",
                "value": "45000",
                "unit": "km",
                "metadata": {
                    "note": "Sample data - TeleGeography requires browser/scraper for live data",
                    "owner": "Meta, Orange, Vodafone, etc.",
                    "status": "active",
                },
                "reference_date": datetime.utcnow().strftime("%Y-%m-%d"),
            },
            {
                "source_id": "telegeo_sample_2",
                "name": "Asia Connect Cable 1",
                "country": "",
                "city": "",
                "latitude": "",
                "longitude": "",
                "value": "12000",
                "unit": "km",
                "metadata": {
                    "note": "Sample data",
                    "owner": "Alibaba, NEC",
                    "status": "planned",
                },
                "reference_date": datetime.utcnow().strftime("%Y-%m-%d"),
            },
        ]
|
||||
|
||||
|
||||
class TeleGeographyLandingPointCollector(BaseCollector):
    """Collects submarine cable landing points from a GitHub mirror."""

    name = "telegeography_landing"
    priority = "P2"
    module = "L2"
    frequency_hours = 168
    data_type = "landing_point"

    async def fetch(self) -> List[Dict[str, Any]]:
        """Download the mirrored landing-point JSON and parse it."""
        url = "https://raw.githubusercontent.com/lintaojlu/submarine_cable_information/main/landing_point.json"

        async with httpx.AsyncClient(timeout=60.0) as client:
            response = await client.get(url)
            response.raise_for_status()
            return self.parse_response(response.json())

    def parse_response(self, data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Normalize raw landing-point records; fall back to sample data."""
        result = []

        for raw in data:
            try:
                result.append(
                    {
                        "source_id": f"telegeo_lp_{raw.get('id', '')}",
                        "name": raw.get("name", "Unknown"),
                        "country": raw.get("country", "Unknown"),
                        "city": raw.get("city", raw.get("name", "")),
                        "latitude": str(raw.get("latitude", "")),
                        "longitude": str(raw.get("longitude", "")),
                        "value": "",
                        "unit": "",
                        "metadata": {
                            "cable_count": len(raw.get("cables", [])),
                            "url": raw.get("url"),
                        },
                        "reference_date": datetime.utcnow().strftime("%Y-%m-%d"),
                    }
                )
            except (ValueError, TypeError, KeyError):
                # Skip malformed records rather than failing the whole batch.
                continue

        return result if result else self._get_sample_data()

    def _get_sample_data(self) -> List[Dict[str, Any]]:
        """Return a static placeholder landing-point record."""
        sample = {
            "source_id": "telegeo_lp_sample_1",
            "name": "Sample Landing Point",
            "country": "United States",
            "city": "Los Angeles, CA",
            "latitude": "34.0522",
            "longitude": "-118.2437",
            "value": "",
            "unit": "",
            "metadata": {"note": "Sample data"},
            "reference_date": datetime.utcnow().strftime("%Y-%m-%d"),
        }
        return [sample]
|
||||
|
||||
|
||||
class TeleGeographyCableSystemCollector(BaseCollector):
    """Collects whole cable-system records from a GitHub mirror."""

    name = "telegeography_systems"
    priority = "P2"
    module = "L2"
    frequency_hours = 168
    data_type = "cable_system"

    async def fetch(self) -> List[Dict[str, Any]]:
        """Download the mirrored cable-system JSON and parse it."""
        url = "https://raw.githubusercontent.com/lintaojlu/submarine_cable_information/main/cable.json"

        async with httpx.AsyncClient(timeout=60.0) as client:
            response = await client.get(url)
            response.raise_for_status()
            return self.parse_response(response.json())

    def parse_response(self, data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Normalize raw cable-system records; fall back to sample data."""
        result = []

        for raw in data:
            try:
                result.append(
                    {
                        "source_id": f"telegeo_sys_{raw.get('id', raw.get('cable_id', ''))}",
                        "name": raw.get("name", raw.get("cable_name", "Unknown")),
                        "country": "",
                        "city": "",
                        "latitude": "",
                        "longitude": "",
                        "value": str(raw.get("length", 0)),
                        "unit": "km",
                        "metadata": {
                            "owner": raw.get("owner"),
                            "operator": raw.get("operator"),
                            "route": raw.get("route"),
                            "countries": raw.get("countries", []),
                            "length_km": raw.get("length"),
                            "rfs": raw.get("rfs"),
                            "status": raw.get("status", "active"),
                            "investment": raw.get("investment"),
                            "url": raw.get("url"),
                        },
                        "reference_date": datetime.utcnow().strftime("%Y-%m-%d"),
                    }
                )
            except (ValueError, TypeError, KeyError):
                # Skip malformed records rather than failing the whole batch.
                continue

        return result if result else self._get_sample_data()

    def _get_sample_data(self) -> List[Dict[str, Any]]:
        """Return a static placeholder cable-system record."""
        sample = {
            "source_id": "telegeo_sys_sample_1",
            "name": "Sample Cable System",
            "country": "",
            "city": "",
            "latitude": "",
            "longitude": "",
            "value": "5000",
            "unit": "km",
            "metadata": {"note": "Sample data"},
            "reference_date": datetime.utcnow().strftime("%Y-%m-%d"),
        }
        return [sample]
|
||||
230
backend/app/services/collectors/top500.py
Normal file
230
backend/app/services/collectors/top500.py
Normal file
@@ -0,0 +1,230 @@
|
||||
"""TOP500 Supercomputer Collector
|
||||
|
||||
Collects data from TOP500 supercomputer rankings.
|
||||
https://top500.org/lists/top500/
|
||||
"""
|
||||
|
||||
import re
|
||||
from typing import Dict, Any, List
|
||||
from datetime import datetime
|
||||
from bs4 import BeautifulSoup
|
||||
import httpx
|
||||
|
||||
from app.services.collectors.base import BaseCollector
|
||||
|
||||
|
||||
class TOP500Collector(BaseCollector):
    """Collects the TOP500 supercomputer ranking by scraping top500.org."""

    name = "top500"
    priority = "P0"
    module = "L1"
    frequency_hours = 4
    data_type = "supercomputer"

    async def fetch(self) -> List[Dict[str, Any]]:
        """Fetch TOP500 data from the website (scraping)."""
        # Get the latest list page
        url = "https://top500.org/lists/top500/list/2025/11/"

        async with httpx.AsyncClient(timeout=60.0) as client:
            response = await client.get(url)
            response.raise_for_status()
            return self.parse_response(response.text)

    def parse_response(self, html: str) -> List[Dict[str, Any]]:
        """Parse the TOP500 HTML ranking table into entries.

        Falls back to built-in sample data when no rows can be parsed.
        """
        soup = BeautifulSoup(html, "html.parser")
        table = self._find_ranking_table(soup)

        data = []
        if table:
            rows = table.find_all("tr")
            for row in rows[1:]:  # Skip header row
                cells = row.find_all(["td", "th"])
                if len(cells) < 6:
                    continue
                try:
                    entry = self._parse_row(cells)
                except (ValueError, IndexError, AttributeError):
                    continue
                if entry is not None:
                    data.append(entry)

        # If scraping failed, return sample data for testing
        if not data:
            data = self._get_sample_data()

        return data

    def _find_ranking_table(self, soup):
        """Locate the ranking table, trying selectors from specific to broad."""
        table = soup.find("table", {"class": "top500-table"})
        if not table:
            # Try alternative table selector
            table = soup.find("table", {"id": "top500"})
        if not table:
            # Try to find any table whose header text looks like the ranking.
            for candidate in soup.find_all("table"):
                if candidate.find(string=re.compile(r"Rank.*System.*Cores.*Rmax", re.I)):
                    table = candidate
                    break
        if not table:
            # Last resort: take the first table on the page.
            tables = soup.find_all("table")
            if tables:
                table = tables[0]
        return table

    def _parse_row(self, cells):
        """Build one ranking entry from a table row, or None if it has no rank."""
        rank_text = cells[0].get_text(strip=True)
        if not rank_text or not rank_text.isdigit():
            return None
        rank = int(rank_text)

        # System name (may be carried in a link's title attribute)
        system_cell = cells[1]
        system_name = system_cell.get_text(strip=True)
        link = system_cell.find("a")
        if link and link.get("title"):
            system_name = link.get("title")

        # Country (flag image alt text takes precedence over cell text)
        country_cell = cells[2]
        country = country_cell.get_text(strip=True)
        img = country_cell.find("img")
        if img and img.get("alt"):
            country = img.get("alt")

        # Extract location (city) from a "City (...)"-shaped cell.
        city = ""
        location_text = country_cell.get_text(strip=True)
        if "(" in location_text and ")" in location_text:
            city = location_text.split("(")[0].strip()

        cores = cells[3].get_text(strip=True).replace(",", "")
        rmax = self._parse_performance(cells[4].get_text(strip=True))
        rpeak = self._parse_performance(cells[5].get_text(strip=True))
        power = cells[6].get_text(strip=True) if len(cells) >= 7 else ""

        return {
            "source_id": f"top500_{rank}",
            "name": system_name,
            "country": country,
            "city": city,
            "latitude": 0.0,
            "longitude": 0.0,
            "value": str(rmax),
            "unit": "PFlop/s",
            "metadata": {
                "rank": rank,
                "r_peak": rpeak,
                "power": power,
                "cores": cores,
            },
            "reference_date": "2025-11-01",
        }

    def _parse_coordinate(self, value: Any) -> float:
        """Parse a coordinate value to float, defaulting to 0.0.

        NOTE(review): currently unused by the parse path (latitude and
        longitude are hard-coded to 0.0 above); kept for interface
        stability.
        """
        if isinstance(value, (int, float)):
            return float(value)
        if isinstance(value, str):
            try:
                return float(value)
            except ValueError:
                return 0.0
        return 0.0

    def _parse_performance(self, text: str) -> float:
        """Parse a performance value from text (handles E/P/T/G/M/K suffixes).

        Returns the value scaled by the suffix multiplier, or 0.0 when
        nothing numeric can be extracted.
        """
        # Fix: strip thousands separators BEFORE matching.  Previously
        # "1,742.00 PFlop/s" matched only the leading "1" because the
        # regex character class excluded commas.
        text = text.strip().upper().replace(",", "")
        multipliers = {
            "E": 1e18,
            "P": 1e15,
            "T": 1e12,
            "G": 1e9,
            "M": 1e6,
            "K": 1e3,
        }

        match = re.match(r"([\d.]+)\s*([EPTGMK])?F?LOP/?S?", text)
        if match:
            value = float(match.group(1))
            suffix = match.group(2)
            if suffix:
                value *= multipliers.get(suffix, 1)
            return value

        # Try simple float parsing
        try:
            return float(text)
        except ValueError:
            return 0.0

    def _get_sample_data(self) -> List[Dict[str, Any]]:
        """Return sample data for testing when scraping fails."""
        return [
            {
                "source_id": "top500_1",
                "name": "El Capitan - HPE Cray EX255a, AMD 4th Gen EPYC 24C 1.8GHz, AMD Instinct MI300A",
                "country": "United States",
                "city": "Livermore, CA",
                "latitude": 37.6819,
                "longitude": -121.7681,
                "value": "1742.00",
                "unit": "PFlop/s",
                "metadata": {
                    "rank": 1,
                    "r_peak": 2746.38,
                    "power": 29581,
                    "cores": 11039616,
                    "manufacturer": "HPE",
                },
                "reference_date": "2025-11-01",
            },
            {
                "source_id": "top500_2",
                "name": "Frontier - HPE Cray EX235a, AMD Optimized 3rd Generation EPYC 64C 2GHz, AMD Instinct MI250X",
                "country": "United States",
                "city": "Oak Ridge, TN",
                "latitude": 36.0107,
                "longitude": -84.2663,
                "value": "1353.00",
                "unit": "PFlop/s",
                "metadata": {
                    "rank": 2,
                    "r_peak": 2055.72,
                    "power": 24607,
                    "cores": 9066176,
                    "manufacturer": "HPE",
                },
                "reference_date": "2025-11-01",
            },
            {
                "source_id": "top500_3",
                "name": "Aurora - HPE Cray EX - Intel Exascale Compute Blade, Xeon CPU Max 9470 52C 2.4GHz, Intel Data Center GPU Max",
                "country": "United States",
                "city": "Argonne, IL",
                "latitude": 41.3784,
                "longitude": -87.8600,
                "value": "1012.00",
                "unit": "PFlop/s",
                "metadata": {
                    "rank": 3,
                    "r_peak": 1980.01,
                    "power": 38698,
                    "cores": 9264128,
                    "manufacturer": "Intel",
                },
                "reference_date": "2025-11-01",
            },
        ]
|
||||
146
backend/app/services/scheduler.py
Normal file
146
backend/app/services/scheduler.py
Normal file
@@ -0,0 +1,146 @@
|
||||
"""Task Scheduler for running collection jobs"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Dict, Any
|
||||
|
||||
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
||||
from apscheduler.triggers.interval import IntervalTrigger
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db.session import async_session_factory
|
||||
from app.services.collectors.registry import collector_registry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
scheduler = AsyncIOScheduler()
|
||||
|
||||
|
||||
# Static mapping from collector name to its `data_sources` primary key,
# used by run_collector_task to stamp collectors with their datasource id.
# NOTE(review): assumes the seeded data_sources rows keep exactly these
# ids — confirm against the DB seed/migration.
COLLECTOR_TO_ID = {
    "top500": 1,
    "epoch_ai_gpu": 2,
    "huggingface_models": 3,
    "huggingface_datasets": 4,
    "huggingface_spaces": 5,
    "peeringdb_ixp": 6,
    "peeringdb_network": 7,
    "peeringdb_facility": 8,
    "telegeography_cables": 9,
    "telegeography_landing": 10,
    "telegeography_systems": 11,
}
|
||||
|
||||
|
||||
async def run_collector_task(collector_name: str):
    """Execute one collection run for *collector_name*.

    Resolves the collector from the registry, stamps it with its
    datasource id (falling back to 1 when unmapped in COLLECTOR_TO_ID),
    and runs it against a fresh async DB session. Any exception from the
    collector is caught and logged, never propagated to the scheduler.
    """
    collector = collector_registry.get(collector_name)
    if not collector:
        logger.error(f"Collector not found: {collector_name}")
        return

    # Get the correct datasource_id for attributing collected records.
    datasource_id = COLLECTOR_TO_ID.get(collector_name, 1)

    async with async_session_factory() as db:
        try:
            # Set the datasource_id on the collector instance.
            # (assumes the collector reads this private attribute when
            # persisting rows — confirm in the base collector)
            collector._datasource_id = datasource_id

            logger.info(f"Running collector: {collector_name} (datasource_id={datasource_id})")
            result = await collector.run(db)
            logger.info(f"Collector {collector_name} completed: {result}")
        except Exception as e:
            logger.error(f"Collector {collector_name} failed: {e}")
|
||||
|
||||
|
||||
def start_scheduler():
    """Register an interval job for every active collector, then start APScheduler.

    Inactive collectors are skipped; each job's id/name is the collector
    name, and existing jobs with the same id are replaced.
    """
    for name, collector in collector_registry.all().items():
        if not collector_registry.is_active(name):
            continue
        scheduler.add_job(
            run_collector_task,
            trigger=IntervalTrigger(hours=collector.frequency_hours),
            id=name,
            name=name,
            replace_existing=True,
            kwargs={"collector_name": name},
        )
        logger.info(f"Scheduled collector: {name} (every {collector.frequency_hours}h)")

    scheduler.start()
    logger.info("Scheduler started")
|
||||
|
||||
|
||||
def stop_scheduler():
    """Shut down the module-level APScheduler instance.

    APScheduler's default shutdown waits for currently-executing jobs
    to finish before returning.
    """
    scheduler.shutdown()
    logger.info("Scheduler stopped")
|
||||
|
||||
|
||||
def get_scheduler_jobs() -> list[Dict[str, Any]]:
    """Snapshot every scheduled job as a JSON-serialisable dict.

    Each entry carries the job id/name, the next run time in ISO-8601
    (None when not scheduled), and the trigger's string form.
    """
    return [
        {
            "id": job.id,
            "name": job.name,
            "next_run_time": job.next_run_time.isoformat() if job.next_run_time else None,
            "trigger": str(job.trigger),
        }
        for job in scheduler.get_jobs()
    ]
|
||||
|
||||
|
||||
def add_job(collector_name: str, hours: int = 4):
    """Schedule *collector_name* on an interval of *hours*.

    Replaces any existing job with the same id. Raises ValueError when
    the collector is not registered.
    """
    if not collector_registry.get(collector_name):
        raise ValueError(f"Collector not found: {collector_name}")

    scheduler.add_job(
        run_collector_task,
        trigger=IntervalTrigger(hours=hours),
        id=collector_name,
        name=collector_name,
        replace_existing=True,
        kwargs={"collector_name": collector_name},
    )
    logger.info(f"Added scheduled job: {collector_name} (every {hours}h)")
|
||||
|
||||
|
||||
def remove_job(collector_name: str):
    """Remove a scheduled job by collector name.

    NOTE(review): propagates APScheduler's JobLookupError when the id is
    unknown — confirm callers handle that.
    """
    scheduler.remove_job(collector_name)
    logger.info(f"Removed scheduled job: {collector_name}")
|
||||
|
||||
|
||||
def pause_job(collector_name: str):
    """Pause a scheduled job (it stays registered but will not fire)."""
    scheduler.pause_job(collector_name)
    logger.info(f"Paused job: {collector_name}")
|
||||
|
||||
|
||||
def resume_job(collector_name: str):
    """Resume a previously paused job so it fires on its trigger again."""
    scheduler.resume_job(collector_name)
    logger.info(f"Resumed job: {collector_name}")
|
||||
|
||||
|
||||
def run_collector_now(collector_name: str) -> bool:
    """Fire a collector immediately, outside its schedule.

    Returns True when the task was submitted, False when the collector
    is unknown or submission failed.

    NOTE(review): asyncio.create_task requires a running event loop, so
    this must be called from async context (e.g. a request handler);
    from sync code the RuntimeError is caught here and reported as
    False — confirm callers are async.
    """
    if not collector_registry.get(collector_name):
        logger.error(f"Collector not found: {collector_name}")
        return False

    try:
        asyncio.create_task(run_collector_task(collector_name))
        logger.info(f"Triggered collector: {collector_name}")
        return True
    except Exception as e:
        logger.error(f"Failed to trigger collector {collector_name}: {e}")
        return False
|
||||
3
backend/app/tasks/__init__.py
Normal file
3
backend/app/tasks/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
"""Tasks package"""
|
||||
|
||||
from app.tasks.scheduler import run_collector_task, run_collector_sync
|
||||
52
backend/app/tasks/scheduler.py
Normal file
52
backend/app/tasks/scheduler.py
Normal file
@@ -0,0 +1,52 @@
|
||||
"""Celery tasks for data collection"""
|
||||
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
from typing import Dict, Any
|
||||
|
||||
from app.db.session import async_session_factory
|
||||
from app.services.collectors.registry import collector_registry
|
||||
|
||||
|
||||
async def run_collector_task(collector_name: str) -> Dict[str, Any]:
    """Run one collector and record the attempt as a CollectionTask row.

    Returns the collector's result dict, or a status dict when the
    collector is missing or disabled. The CollectionTask row is updated
    with the final status even when the collector raises, so rows never
    stay stuck in "running".
    """
    collector = collector_registry.get(collector_name)
    if not collector:
        return {"status": "failed", "error": f"Collector {collector_name} not found"}

    if not collector_registry.is_active(collector_name):
        return {"status": "skipped", "reason": "Collector is disabled"}

    async with async_session_factory() as db:
        from sqlalchemy import text  # local import keeps module import-light

        from app.models.task import CollectionTask

        # Find the datasource row for this collector class.
        # SQLAlchemy 2.x rejects plain SQL strings passed to execute(),
        # so the query must be wrapped in text().
        result = await db.execute(
            text("SELECT id FROM data_sources WHERE collector_class = :class_name"),
            {"class_name": collector.__class__.__name__},
        )
        datasource = result.fetchone()

        # Record the attempt before running so a crash still leaves a trace.
        task = CollectionTask(
            datasource_id=datasource[0] if datasource else 0,
            status="running",
            started_at=datetime.utcnow(),
        )
        db.add(task)
        await db.commit()

        try:
            outcome = await collector.run(db)
        except Exception as e:
            # Mark the task failed instead of leaving it "running" forever.
            task.status = "failed"
            task.completed_at = datetime.utcnow()
            task.error_message = str(e)
            await db.commit()
            return {"status": "failed", "error": str(e)}

        task.status = outcome["status"]
        task.completed_at = datetime.utcnow()
        task.records_processed = outcome.get("records_processed", 0)
        task.error_message = outcome.get("error")
        await db.commit()

        return outcome
|
||||
|
||||
|
||||
def run_collector_sync(collector_name: str) -> Dict[str, Any]:
    """Synchronous wrapper for running collectors.

    Spins up a fresh event loop via asyncio.run(); must not be called
    from inside an already-running loop (asyncio.run would raise
    RuntimeError there).
    """
    return asyncio.run(run_collector_task(collector_name))
|
||||
10
backend/pytest.ini
Normal file
10
backend/pytest.ini
Normal file
@@ -0,0 +1,10 @@
|
||||
[pytest]
|
||||
asyncio_mode = auto
|
||||
testpaths = tests
|
||||
python_files = test_*.py
|
||||
python_functions = test_*
|
||||
python_classes = Test*
|
||||
addopts = -v --tb=short
|
||||
filterwarnings =
|
||||
ignore::DeprecationWarning
|
||||
ignore::PendingDeprecationWarning
|
||||
18
backend/requirements.txt
Normal file
18
backend/requirements.txt
Normal file
@@ -0,0 +1,18 @@
|
||||
fastapi>=0.109.0
|
||||
uvicorn[standard]>=0.27.0
|
||||
sqlalchemy[asyncio]>=2.0.25
|
||||
asyncpg>=0.29.0
|
||||
redis>=5.0.1
|
||||
pydantic>=2.5.0
|
||||
pydantic-settings>=2.1.0
|
||||
python-jose[cryptography]>=3.3.0
|
||||
passlib[bcrypt]>=1.7.4
|
||||
python-multipart>=0.0.6
|
||||
httpx>=0.26.0
|
||||
beautifulsoup4>=4.12.0
|
||||
aiofiles>=23.2.1
|
||||
python-dotenv>=1.0.0
|
||||
email-validator
|
||||
apscheduler>=3.10.4
|
||||
pytest>=7.4.0
|
||||
pytest-asyncio>=0.23.0
|
||||
35
backend/scripts/init_admin.py
Normal file
35
backend/scripts/init_admin.py
Normal file
@@ -0,0 +1,35 @@
|
||||
"""Create default admin user"""
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
|
||||
sys.path.insert(0, ".")
|
||||
|
||||
from app.core.security import get_password_hash
|
||||
from app.db.session import engine, async_session_factory
|
||||
from app.models.user import User
|
||||
|
||||
|
||||
async def create_admin():
    """Insert the default admin account unless one already exists.

    Idempotent bootstrap helper: checks for a user named 'admin' and
    creates it with the hard-coded dev credentials when missing.
    """
    from sqlalchemy import text

    async with async_session_factory() as session:
        existing = await session.execute(text("SELECT id FROM users WHERE username = 'admin'"))
        if existing.fetchone():
            print("Admin user already exists")
            return

        session.add(
            User(
                username="admin",
                email="admin@planet.local",
                password_hash=get_password_hash("admin123"),
                role="super_admin",
                is_active=True,
            )
        )
        await session.commit()
        print("Admin user created: admin / admin123")


if __name__ == "__main__":
    asyncio.run(create_admin())
|
||||
45
backend/scripts/init_db.py
Normal file
45
backend/scripts/init_db.py
Normal file
@@ -0,0 +1,45 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Create initial admin user with pre-generated hash"""
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
|
||||
sys.path.insert(0, "/app")
|
||||
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
import bcrypt
|
||||
|
||||
|
||||
# Generate proper bcrypt hash
|
||||
ADMIN_PASSWORD_HASH = bcrypt.hashpw("admin123".encode(), bcrypt.gensalt()).decode()
|
||||
|
||||
|
||||
async def create_admin():
    """Create the default admin user via raw SQL against the compose DB.

    One-shot bootstrap script meant to run inside the docker network
    (hostname `postgres`). Idempotent: exits early when an 'admin' row
    already exists.

    NOTE(review): the connection URL and admin password are hard-coded —
    local/dev bootstrap only, never run against production.
    """
    DATABASE_URL = "postgresql+asyncpg://postgres:postgres@postgres:5432/planet_db"
    engine = create_async_engine(DATABASE_URL, echo=False)
    async_session = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)

    async with async_session() as session:
        result = await session.execute(
            text("SELECT id FROM users WHERE username = 'admin'")
        )
        if result.fetchone():
            print("Admin user already exists")
            return

        # Insert directly with SQL so the script has no dependency on the
        # ORM models being importable.
        await session.execute(
            text("""
            INSERT INTO users (username, email, password_hash, role, is_active, created_at, updated_at)
            VALUES ('admin', 'admin@planet.local', :password, 'super_admin', true, NOW(), NOW())
            """),
            {"password": ADMIN_PASSWORD_HASH},
        )
        await session.commit()
        print(f"Admin user created: admin / admin123")
        print(f"Hash: {ADMIN_PASSWORD_HASH}")


if __name__ == "__main__":
    asyncio.run(create_admin())
|
||||
1
backend/tests/__init__.py
Normal file
1
backend/tests/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Test configuration"""
|
||||
BIN
backend/tests/__pycache__/__init__.cpython-311.pyc
Normal file
BIN
backend/tests/__pycache__/__init__.cpython-311.pyc
Normal file
Binary file not shown.
BIN
backend/tests/__pycache__/conftest.cpython-311-pytest-9.0.2.pyc
Normal file
BIN
backend/tests/__pycache__/conftest.cpython-311-pytest-9.0.2.pyc
Normal file
Binary file not shown.
BIN
backend/tests/__pycache__/test_api.cpython-311-pytest-9.0.2.pyc
Normal file
BIN
backend/tests/__pycache__/test_api.cpython-311-pytest-9.0.2.pyc
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
103
backend/tests/conftest.py
Normal file
103
backend/tests/conftest.py
Normal file
@@ -0,0 +1,103 @@
|
||||
"""Pytest configuration and fixtures"""
|
||||
|
||||
import pytest
|
||||
import asyncio
|
||||
from typing import AsyncGenerator
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
def event_loop():
    """Session-scoped event loop shared by all async tests.

    NOTE(review): overriding the `event_loop` fixture is deprecated in
    recent pytest-asyncio releases (loop scope is configured via
    asyncio_mode / loop-scope settings instead) — confirm against the
    installed pytest-asyncio version.
    """
    loop = asyncio.get_event_loop_policy().new_event_loop()
    yield loop
    loop.close()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_db_session():
|
||||
"""Mock database session"""
|
||||
session = AsyncMock(spec=AsyncSession)
|
||||
session.add = MagicMock()
|
||||
session.commit = AsyncMock()
|
||||
session.execute = AsyncMock()
|
||||
session.refresh = AsyncMock()
|
||||
session.close = AsyncMock()
|
||||
return session
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_top500_response():
|
||||
"""Sample TOP500 API response"""
|
||||
return {
|
||||
"items": [
|
||||
{
|
||||
"rank": 1,
|
||||
"system_name": "Frontier",
|
||||
"country": "USA",
|
||||
"city": "Oak Ridge",
|
||||
"latitude": 35.9322,
|
||||
"longitude": -84.3108,
|
||||
"manufacturer": "HPE",
|
||||
"r_max": 1102000.0,
|
||||
"r_peak": 1685000.0,
|
||||
"power": 21510.0,
|
||||
"cores": 8730112,
|
||||
"interconnect": "Slingshot 11",
|
||||
"os": "CentOS",
|
||||
},
|
||||
{
|
||||
"rank": 2,
|
||||
"system_name": "Fugaku",
|
||||
"country": "Japan",
|
||||
"city": "Kobe",
|
||||
"latitude": 34.6913,
|
||||
"longitude": 135.1830,
|
||||
"manufacturer": "Fujitsu",
|
||||
"r_max": 442010.0,
|
||||
"r_peak": 537212.0,
|
||||
"power": 29899.0,
|
||||
"cores": 7630848,
|
||||
"interconnect": "Tofu interconnect D",
|
||||
"os": "RHEL",
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_huggingface_response():
|
||||
"""Sample Hugging Face API response"""
|
||||
return {
|
||||
"models": [
|
||||
{
|
||||
"id": "bert-base-uncased",
|
||||
"author": "google",
|
||||
"description": "BERT base model",
|
||||
"likes": 25000,
|
||||
"downloads": 5000000,
|
||||
"language": "en",
|
||||
"tags": ["transformer", "bert"],
|
||||
"pipeline_tag": "feature-extraction",
|
||||
"library_name": "transformers",
|
||||
"createdAt": "2024-01-15T10:00:00Z",
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_alert_data():
|
||||
"""Sample alert data"""
|
||||
return {
|
||||
"id": 1,
|
||||
"severity": "warning",
|
||||
"status": "active",
|
||||
"datasource_id": 2,
|
||||
"datasource_name": "Epoch AI",
|
||||
"message": "API response time > 30s",
|
||||
"created_at": "2024-01-20T09:30:00Z",
|
||||
"acknowledged_by": None,
|
||||
}
|
||||
108
backend/tests/test_api.py
Normal file
108
backend/tests/test_api.py
Normal file
@@ -0,0 +1,108 @@
|
||||
"""API endpoint tests"""
|
||||
|
||||
import pytest
|
||||
from datetime import datetime
|
||||
from unittest.mock import patch, AsyncMock
|
||||
from httpx import AsyncClient, ASGITransport
|
||||
|
||||
from app.main import app
|
||||
from app.core.config import settings
|
||||
from app.core.security import create_access_token
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def auth_headers():
|
||||
"""Create authentication headers"""
|
||||
token = create_access_token({"sub": "1", "username": "testuser"})
|
||||
return {"Authorization": f"Bearer {token}"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_health_check():
|
||||
"""Test health check endpoint"""
|
||||
transport = ASGITransport(app=app)
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as client:
|
||||
response = await client.get("/health")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["status"] == "healthy"
|
||||
assert "version" in data
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_root_endpoint():
|
||||
"""Test root endpoint"""
|
||||
transport = ASGITransport(app=app)
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as client:
|
||||
response = await client.get("/")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["name"] == settings.PROJECT_NAME
|
||||
assert data["version"] == settings.VERSION
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dashboard_stats_without_auth():
|
||||
"""Test dashboard stats requires authentication"""
|
||||
transport = ASGITransport(app=app)
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as client:
|
||||
response = await client.get("/api/v1/dashboard/stats")
|
||||
assert response.status_code == 401
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_dashboard_stats_with_auth(auth_headers):
    """Test dashboard stats with authentication.

    Mocks the cache (miss on get, no-op set) and the DB session so the
    endpoint computes stats from empty/zero results.

    NOTE(review): patching "app.db.session.get_db" may not affect
    FastAPI's dependency injection (routers resolve the dependency
    object they imported); app.dependency_overrides is the reliable
    mechanism — confirm this mock is actually exercised.
    """
    with patch("app.api.v1.dashboard.cache.get", return_value=None):
        with patch("app.api.v1.dashboard.cache.set", return_value=True):
            with patch("app.db.session.get_db") as mock_get_db:
                # AsyncMock session whose execute() yields empty results.
                mock_session = AsyncMock()
                mock_result = AsyncMock()
                mock_result.scalar.return_value = 0
                mock_result.fetchall.return_value = []
                mock_session.execute.return_value = mock_result

                async def mock_db_context():
                    yield mock_session

                mock_get_db.return_value = mock_db_context()

                transport = ASGITransport(app=app)
                async with AsyncClient(transport=transport, base_url="http://test") as client:
                    response = await client.get(
                        "/api/v1/dashboard/stats",
                        headers=auth_headers,
                    )
                    assert response.status_code == 200
                    data = response.json()
                    assert "total_datasources" in data
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_alerts_without_auth():
|
||||
"""Test alerts endpoint requires authentication"""
|
||||
transport = ASGITransport(app=app)
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as client:
|
||||
response = await client.get("/api/v1/alerts")
|
||||
assert response.status_code == 401
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_alerts_endpoint_with_auth(auth_headers):
|
||||
"""Test alerts endpoint with authentication"""
|
||||
transport = ASGITransport(app=app)
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as client:
|
||||
response = await client.get("/api/v1/alerts", headers=auth_headers)
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_token():
|
||||
"""Test that invalid token is rejected"""
|
||||
transport = ASGITransport(app=app)
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as client:
|
||||
response = await client.get(
|
||||
"/api/v1/dashboard/stats",
|
||||
headers={"Authorization": "Bearer invalid_token"},
|
||||
)
|
||||
assert response.status_code == 401
|
||||
112
backend/tests/test_collectors.py
Normal file
112
backend/tests/test_collectors.py
Normal file
@@ -0,0 +1,112 @@
|
||||
"""Unit tests for data collectors"""
|
||||
|
||||
import pytest
|
||||
from datetime import datetime
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from app.services.collectors.top500 import TOP500Collector
|
||||
from app.services.collectors.base import BaseCollector, HTTPCollector
|
||||
|
||||
|
||||
class TestBaseCollector:
|
||||
"""Tests for BaseCollector"""
|
||||
|
||||
def test_base_collector_attributes(self):
|
||||
"""Test base collector has correct default attributes via concrete class"""
|
||||
collector = TOP500Collector()
|
||||
assert collector.name == "top500"
|
||||
assert collector.priority == "P0"
|
||||
assert collector.module == "L1"
|
||||
assert collector.frequency_hours == 4
|
||||
|
||||
|
||||
class TestTOP500Collector:
|
||||
"""Tests for TOP500Collector"""
|
||||
|
||||
def test_parse_coordinate_valid_float(self):
|
||||
"""Test parsing valid float coordinate"""
|
||||
collector = TOP500Collector()
|
||||
assert collector._parse_coordinate(45.5) == 45.5
|
||||
|
||||
def test_parse_coordinate_valid_string(self):
|
||||
"""Test parsing valid string coordinate"""
|
||||
collector = TOP500Collector()
|
||||
assert collector._parse_coordinate("45.5") == 45.5
|
||||
|
||||
def test_parse_coordinate_invalid_string(self):
|
||||
"""Test parsing invalid string coordinate"""
|
||||
collector = TOP500Collector()
|
||||
assert collector._parse_coordinate("invalid") == 0.0
|
||||
|
||||
def test_parse_coordinate_none(self):
|
||||
"""Test parsing None coordinate"""
|
||||
collector = TOP500Collector()
|
||||
assert collector._parse_coordinate(None) == 0.0
|
||||
|
||||
def test_parse_response_empty(self):
|
||||
"""Test parsing empty response"""
|
||||
collector = TOP500Collector()
|
||||
result = collector.parse_response({"items": []})
|
||||
assert result == []
|
||||
|
||||
def test_parse_response_single_item(self):
|
||||
"""Test parsing single item response"""
|
||||
collector = TOP500Collector()
|
||||
response = {
|
||||
"items": [
|
||||
{
|
||||
"rank": 1,
|
||||
"system_name": "Test Supercomputer",
|
||||
"country": "USA",
|
||||
"city": "San Francisco",
|
||||
"latitude": 37.7749,
|
||||
"longitude": -122.4194,
|
||||
"manufacturer": "Test Corp",
|
||||
"r_max": 100000.0,
|
||||
"r_peak": 150000.0,
|
||||
"power": 5000.0,
|
||||
"cores": 100000,
|
||||
"interconnect": "InfiniBand",
|
||||
"os": "Linux",
|
||||
}
|
||||
]
|
||||
}
|
||||
result = collector.parse_response(response)
|
||||
assert len(result) == 1
|
||||
assert result[0]["cluster_id"] == "top500_1"
|
||||
assert result[0]["name"] == "Test Supercomputer"
|
||||
assert result[0]["country"] == "USA"
|
||||
assert result[0]["rank"] == 1
|
||||
assert result[0]["source"] == "TOP500"
|
||||
|
||||
def test_parse_response_skips_invalid_item(self):
|
||||
"""Test parsing skips items with missing data"""
|
||||
collector = TOP500Collector()
|
||||
response = {
|
||||
"items": [
|
||||
{"rank": 1, "system_name": "Valid"},
|
||||
{"rank": None, "system_name": "Invalid"},
|
||||
]
|
||||
}
|
||||
result = collector.parse_response(response)
|
||||
assert len(result) == 1
|
||||
assert result[0]["name"] == "Valid"
|
||||
|
||||
|
||||
class TestHTTPCollector:
|
||||
"""Tests for HTTPCollector"""
|
||||
|
||||
def test_http_collector_attributes(self):
|
||||
"""Test HTTP collector has correct default attributes via concrete class"""
|
||||
collector = TOP500Collector()
|
||||
assert collector.base_url == "https://top500.org/api/v1.0/lists/"
|
||||
assert collector.name == "top500"
|
||||
assert collector.priority == "P0"
|
||||
|
||||
def test_collector_has_required_methods(self):
|
||||
"""Test HTTP collector has required methods"""
|
||||
collector = TOP500Collector()
|
||||
assert hasattr(collector, "fetch")
|
||||
assert hasattr(collector, "parse_response")
|
||||
assert callable(collector.fetch)
|
||||
assert callable(collector.parse_response)
|
||||
131
backend/tests/test_models.py
Normal file
131
backend/tests/test_models.py
Normal file
@@ -0,0 +1,131 @@
|
||||
"""Unit tests for models"""
|
||||
|
||||
import pytest
|
||||
from datetime import datetime
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from app.models.user import User
|
||||
from app.models.alert import Alert, AlertSeverity, AlertStatus
|
||||
from app.models.task import CollectionTask
|
||||
|
||||
|
||||
class TestUserModel:
|
||||
"""Tests for User model"""
|
||||
|
||||
def test_user_creation(self):
|
||||
"""Test user model creation"""
|
||||
user = User(
|
||||
id=1,
|
||||
username="testuser",
|
||||
email="test@example.com",
|
||||
password_hash="hashed_password",
|
||||
role="admin",
|
||||
is_active=True,
|
||||
)
|
||||
assert user.id == 1
|
||||
assert user.username == "testuser"
|
||||
assert user.email == "test@example.com"
|
||||
assert user.is_active is True
|
||||
|
||||
def test_user_role_assignment(self):
|
||||
"""Test user role assignment"""
|
||||
user = User(
|
||||
username="testuser",
|
||||
email="test@example.com",
|
||||
password_hash="hashed",
|
||||
role="admin",
|
||||
)
|
||||
assert user.role == "admin"
|
||||
|
||||
def test_user_password_hash(self):
|
||||
"""Test user password hash attribute"""
|
||||
user = User(
|
||||
username="testuser",
|
||||
email="test@example.com",
|
||||
password_hash="hashed_password",
|
||||
)
|
||||
assert user.password_hash == "hashed_password"
|
||||
|
||||
|
||||
class TestAlertModel:
|
||||
"""Tests for Alert model"""
|
||||
|
||||
def test_alert_creation(self):
|
||||
"""Test alert model creation"""
|
||||
alert = Alert(
|
||||
id=1,
|
||||
severity=AlertSeverity.WARNING,
|
||||
status=AlertStatus.ACTIVE,
|
||||
message="Test alert message",
|
||||
datasource_id=1,
|
||||
datasource_name="Test Source",
|
||||
)
|
||||
assert alert.id == 1
|
||||
assert alert.severity == AlertSeverity.WARNING
|
||||
assert alert.status == AlertStatus.ACTIVE
|
||||
assert alert.message == "Test alert message"
|
||||
|
||||
def test_alert_to_dict(self):
|
||||
"""Test alert to_dict method"""
|
||||
alert = Alert(
|
||||
id=1,
|
||||
severity=AlertSeverity.CRITICAL,
|
||||
status=AlertStatus.ACTIVE,
|
||||
message="Critical alert",
|
||||
datasource_id=2,
|
||||
datasource_name="Test Source",
|
||||
created_at=datetime(2024, 1, 1, 12, 0, 0),
|
||||
)
|
||||
result = alert.to_dict()
|
||||
assert result["id"] == 1
|
||||
assert result["severity"] == "critical"
|
||||
assert result["status"] == "active"
|
||||
assert result["message"] == "Critical alert"
|
||||
assert result["created_at"] == "2024-01-01T12:00:00"
|
||||
|
||||
def test_alert_severity_enum(self):
|
||||
"""Test alert severity enum values"""
|
||||
assert AlertSeverity.CRITICAL.value == "critical"
|
||||
assert AlertSeverity.WARNING.value == "warning"
|
||||
assert AlertSeverity.INFO.value == "info"
|
||||
|
||||
def test_alert_status_enum(self):
|
||||
"""Test alert status enum values"""
|
||||
assert AlertStatus.ACTIVE.value == "active"
|
||||
assert AlertStatus.ACKNOWLEDGED.value == "acknowledged"
|
||||
assert AlertStatus.RESOLVED.value == "resolved"
|
||||
|
||||
|
||||
class TestCollectionTaskModel:
|
||||
"""Tests for CollectionTask model"""
|
||||
|
||||
def test_task_creation(self):
|
||||
"""Test collection task creation"""
|
||||
task = CollectionTask(
|
||||
id=1,
|
||||
datasource_id=1,
|
||||
status="running",
|
||||
records_processed=0,
|
||||
started_at=datetime.utcnow(),
|
||||
)
|
||||
assert task.id == 1
|
||||
assert task.datasource_id == 1
|
||||
assert task.status == "running"
|
||||
|
||||
def test_task_with_records(self):
|
||||
"""Test collection task with records processed"""
|
||||
task = CollectionTask(
|
||||
datasource_id=1,
|
||||
status="success",
|
||||
records_processed=100,
|
||||
)
|
||||
assert task.records_processed == 100
|
||||
|
||||
def test_task_error_message(self):
|
||||
"""Test collection task with error message"""
|
||||
task = CollectionTask(
|
||||
datasource_id=1,
|
||||
status="failed",
|
||||
error_message="Connection timeout",
|
||||
)
|
||||
assert task.error_message == "Connection timeout"
|
||||
113
backend/tests/test_security.py
Normal file
113
backend/tests/test_security.py
Normal file
@@ -0,0 +1,113 @@
|
||||
"""Unit tests for security module"""
|
||||
|
||||
import pytest
|
||||
from datetime import datetime, timedelta
|
||||
from jose import jwt
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from app.core.security import (
|
||||
create_access_token,
|
||||
create_refresh_token,
|
||||
verify_password,
|
||||
get_password_hash,
|
||||
)
|
||||
from app.core.config import settings
|
||||
|
||||
|
||||
class TestPasswordHashing:
|
||||
"""Tests for password hashing functions"""
|
||||
|
||||
def test_hash_password(self):
|
||||
"""Test password hashing"""
|
||||
password = "test_password_123"
|
||||
hashed = get_password_hash(password)
|
||||
assert hashed != password
|
||||
assert len(hashed) > 0
|
||||
|
||||
def test_verify_correct_password(self):
|
||||
"""Test verification of correct password"""
|
||||
password = "test_password_123"
|
||||
hashed = get_password_hash(password)
|
||||
assert verify_password(password, hashed) is True
|
||||
|
||||
def test_verify_incorrect_password(self):
|
||||
"""Test verification of incorrect password"""
|
||||
password = "test_password_123"
|
||||
hashed = get_password_hash(password)
|
||||
assert verify_password("wrong_password", hashed) is False
|
||||
|
||||
def test_hash_is_unique(self):
|
||||
"""Test that hashes are unique for same password"""
|
||||
password = "test_password_123"
|
||||
hash1 = get_password_hash(password)
|
||||
hash2 = get_password_hash(password)
|
||||
assert hash1 != hash2 # bcrypt adds salt
|
||||
|
||||
|
||||
class TestTokenCreation:
|
||||
"""Tests for token creation functions"""
|
||||
|
||||
def test_create_access_token(self):
|
||||
"""Test access token creation"""
|
||||
data = {"sub": "123", "username": "testuser"}
|
||||
token = create_access_token(data)
|
||||
assert token is not None
|
||||
assert len(token) > 0
|
||||
|
||||
payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
|
||||
assert payload["sub"] == "123"
|
||||
assert payload["username"] == "testuser"
|
||||
assert payload["type"] == "access"
|
||||
|
||||
def test_create_refresh_token(self):
|
||||
"""Test refresh token creation"""
|
||||
data = {"sub": "123"}
|
||||
token = create_refresh_token(data)
|
||||
assert token is not None
|
||||
|
||||
payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
|
||||
assert payload["sub"] == "123"
|
||||
assert payload["type"] == "refresh"
|
||||
|
||||
def test_access_token_expiration(self):
|
||||
"""Test access token has correct expiration"""
|
||||
data = {"sub": "123"}
|
||||
token = create_access_token(data)
|
||||
payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
|
||||
exp_timestamp = payload["exp"]
|
||||
# Token should expire in approximately 15 minutes (accounting for timezone)
|
||||
expected_minutes = settings.ACCESS_TOKEN_EXPIRE_MINUTES
|
||||
# The timestamp is in seconds since epoch
|
||||
import time
|
||||
|
||||
now_timestamp = time.time()
|
||||
minutes_diff = (exp_timestamp - now_timestamp) / 60
|
||||
assert expected_minutes - 1 < minutes_diff < expected_minutes + 1
|
||||
|
||||
def test_refresh_token_expiration(self):
|
||||
"""Test refresh token has correct expiration"""
|
||||
data = {"sub": "123"}
|
||||
token = create_refresh_token(data)
|
||||
payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
|
||||
exp = datetime.fromtimestamp(payload["exp"])
|
||||
now = datetime.utcnow()
|
||||
# Token should expire in approximately 7 days (with some tolerance)
|
||||
delta = exp - now
|
||||
assert delta.days >= 6 # At least 6 days
|
||||
assert delta.days <= 8 # Less than 8 days
|
||||
|
||||
|
||||
class TestJWTSecurity:
|
||||
"""Tests for JWT security features"""
|
||||
|
||||
def test_invalid_token_raises_error(self):
|
||||
"""Test that invalid token raises JWTError"""
|
||||
with pytest.raises(jwt.JWTError):
|
||||
jwt.decode("invalid_token", settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
|
||||
|
||||
def test_token_with_wrong_secret_raises_error(self):
|
||||
"""Test that token with wrong secret raises error"""
|
||||
data = {"sub": "123"}
|
||||
token = create_access_token(data)
|
||||
with pytest.raises(jwt.JWTError):
|
||||
jwt.decode(token, "wrong_secret", algorithms=[settings.ALGORITHM])
|
||||
Reference in New Issue
Block a user