Refine data management and collection workflows
This commit is contained in:
@@ -7,6 +7,8 @@ import json
|
||||
import csv
|
||||
import io
|
||||
|
||||
from app.core.collected_data_fields import get_metadata_field
|
||||
from app.core.countries import COUNTRY_OPTIONS, get_country_search_variants, normalize_country
|
||||
from app.db.session import get_db
|
||||
from app.models.user import User
|
||||
from app.core.security import get_current_user
|
||||
@@ -15,8 +17,119 @@ from app.models.collected_data import CollectedData
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
COUNTRY_SQL = "metadata->>'country'"
|
||||
SEARCHABLE_SQL = [
|
||||
"name",
|
||||
"title",
|
||||
"description",
|
||||
"source",
|
||||
"data_type",
|
||||
"source_id",
|
||||
"metadata::text",
|
||||
]
|
||||
|
||||
|
||||
def parse_multi_values(value: Optional[str]) -> list[str]:
|
||||
if not value:
|
||||
return []
|
||||
return [item.strip() for item in value.split(",") if item.strip()]
|
||||
|
||||
|
||||
def build_in_condition(field_sql: str, values: list[str], param_prefix: str, params: dict) -> str:
|
||||
placeholders = []
|
||||
for index, value in enumerate(values):
|
||||
key = f"{param_prefix}_{index}"
|
||||
params[key] = value
|
||||
placeholders.append(f":{key}")
|
||||
return f"{field_sql} IN ({', '.join(placeholders)})"
|
||||
|
||||
|
||||
def build_search_condition(search: Optional[str], params: dict) -> Optional[str]:
|
||||
if not search:
|
||||
return None
|
||||
|
||||
normalized = search.strip()
|
||||
if not normalized:
|
||||
return None
|
||||
|
||||
search_terms = [normalized]
|
||||
for variant in get_country_search_variants(normalized):
|
||||
if variant.casefold() not in {term.casefold() for term in search_terms}:
|
||||
search_terms.append(variant)
|
||||
|
||||
conditions = []
|
||||
for index, term in enumerate(search_terms):
|
||||
params[f"search_{index}"] = f"%{term}%"
|
||||
conditions.extend(f"{field} ILIKE :search_{index}" for field in SEARCHABLE_SQL)
|
||||
|
||||
params["search_exact"] = normalized
|
||||
params["search_prefix"] = f"{normalized}%"
|
||||
|
||||
canonical_variants = get_country_search_variants(normalized)
|
||||
canonical = canonical_variants[0] if canonical_variants else None
|
||||
params["country_search_exact"] = canonical or normalized
|
||||
params["country_search_prefix"] = f"{(canonical or normalized)}%"
|
||||
|
||||
return "(" + " OR ".join(conditions) + ")"
|
||||
|
||||
|
||||
def build_search_rank_sql(search: Optional[str]) -> str:
|
||||
if not search or not search.strip():
|
||||
return "0"
|
||||
|
||||
return """
|
||||
CASE
|
||||
WHEN name ILIKE :search_exact THEN 700
|
||||
WHEN name ILIKE :search_prefix THEN 600
|
||||
WHEN title ILIKE :search_exact THEN 500
|
||||
WHEN title ILIKE :search_prefix THEN 400
|
||||
WHEN metadata->>'country' ILIKE :country_search_exact THEN 380
|
||||
WHEN metadata->>'country' ILIKE :country_search_prefix THEN 340
|
||||
WHEN source_id ILIKE :search_exact THEN 350
|
||||
WHEN source ILIKE :search_exact THEN 300
|
||||
WHEN data_type ILIKE :search_exact THEN 250
|
||||
WHEN description ILIKE :search_0 THEN 150
|
||||
WHEN metadata::text ILIKE :search_0 THEN 100
|
||||
WHEN title ILIKE :search_0 THEN 80
|
||||
WHEN name ILIKE :search_0 THEN 60
|
||||
WHEN source ILIKE :search_0 THEN 40
|
||||
WHEN data_type ILIKE :search_0 THEN 30
|
||||
WHEN source_id ILIKE :search_0 THEN 20
|
||||
ELSE 0
|
||||
END
|
||||
"""
|
||||
|
||||
|
||||
def serialize_collected_row(row) -> dict:
|
||||
metadata = row[7]
|
||||
return {
|
||||
"id": row[0],
|
||||
"source": row[1],
|
||||
"source_id": row[2],
|
||||
"data_type": row[3],
|
||||
"name": row[4],
|
||||
"title": row[5],
|
||||
"description": row[6],
|
||||
"country": get_metadata_field(metadata, "country"),
|
||||
"city": get_metadata_field(metadata, "city"),
|
||||
"latitude": get_metadata_field(metadata, "latitude"),
|
||||
"longitude": get_metadata_field(metadata, "longitude"),
|
||||
"value": get_metadata_field(metadata, "value"),
|
||||
"unit": get_metadata_field(metadata, "unit"),
|
||||
"metadata": metadata,
|
||||
"cores": get_metadata_field(metadata, "cores"),
|
||||
"rmax": get_metadata_field(metadata, "rmax"),
|
||||
"rpeak": get_metadata_field(metadata, "rpeak"),
|
||||
"power": get_metadata_field(metadata, "power"),
|
||||
"collected_at": row[8].isoformat() if row[8] else None,
|
||||
"reference_date": row[9].isoformat() if row[9] else None,
|
||||
"is_valid": row[10],
|
||||
}
|
||||
|
||||
|
||||
@router.get("")
|
||||
async def list_collected_data(
|
||||
mode: str = Query("current", description="查询模式: current/history"),
|
||||
source: Optional[str] = Query(None, description="数据源过滤"),
|
||||
data_type: Optional[str] = Query(None, description="数据类型过滤"),
|
||||
country: Optional[str] = Query(None, description="国家过滤"),
|
||||
@@ -27,25 +140,30 @@ async def list_collected_data(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""查询采集的数据列表"""
|
||||
normalized_country = normalize_country(country) if country else None
|
||||
source_values = parse_multi_values(source)
|
||||
data_type_values = parse_multi_values(data_type)
|
||||
|
||||
# Build WHERE clause
|
||||
conditions = []
|
||||
params = {}
|
||||
|
||||
if source:
|
||||
conditions.append("source = :source")
|
||||
params["source"] = source
|
||||
if data_type:
|
||||
conditions.append("data_type = :data_type")
|
||||
params["data_type"] = data_type
|
||||
if country:
|
||||
conditions.append("country = :country")
|
||||
params["country"] = country
|
||||
if search:
|
||||
conditions.append("(name ILIKE :search OR title ILIKE :search)")
|
||||
params["search"] = f"%{search}%"
|
||||
if mode != "history":
|
||||
conditions.append("COALESCE(is_current, TRUE) = TRUE")
|
||||
|
||||
if source_values:
|
||||
conditions.append(build_in_condition("source", source_values, "source", params))
|
||||
if data_type_values:
|
||||
conditions.append(build_in_condition("data_type", data_type_values, "data_type", params))
|
||||
if normalized_country:
|
||||
conditions.append(f"{COUNTRY_SQL} = :country")
|
||||
params["country"] = normalized_country
|
||||
search_condition = build_search_condition(search, params)
|
||||
if search_condition:
|
||||
conditions.append(search_condition)
|
||||
|
||||
where_sql = " AND ".join(conditions) if conditions else "1=1"
|
||||
search_rank_sql = build_search_rank_sql(search)
|
||||
|
||||
# Calculate offset
|
||||
offset = (page - 1) * page_size
|
||||
@@ -58,11 +176,11 @@ async def list_collected_data(
|
||||
# Query data
|
||||
query = text(f"""
|
||||
SELECT id, source, source_id, data_type, name, title, description,
|
||||
country, city, latitude, longitude, value, unit,
|
||||
metadata, collected_at, reference_date, is_valid
|
||||
metadata, collected_at, reference_date, is_valid,
|
||||
{search_rank_sql} AS search_rank
|
||||
FROM collected_data
|
||||
WHERE {where_sql}
|
||||
ORDER BY collected_at DESC
|
||||
ORDER BY search_rank DESC, collected_at DESC
|
||||
LIMIT :limit OFFSET :offset
|
||||
""")
|
||||
params["limit"] = page_size
|
||||
@@ -73,27 +191,7 @@ async def list_collected_data(
|
||||
|
||||
data = []
|
||||
for row in rows:
|
||||
data.append(
|
||||
{
|
||||
"id": row[0],
|
||||
"source": row[1],
|
||||
"source_id": row[2],
|
||||
"data_type": row[3],
|
||||
"name": row[4],
|
||||
"title": row[5],
|
||||
"description": row[6],
|
||||
"country": row[7],
|
||||
"city": row[8],
|
||||
"latitude": row[9],
|
||||
"longitude": row[10],
|
||||
"value": row[11],
|
||||
"unit": row[12],
|
||||
"metadata": row[13],
|
||||
"collected_at": row[14].isoformat() if row[14] else None,
|
||||
"reference_date": row[15].isoformat() if row[15] else None,
|
||||
"is_valid": row[16],
|
||||
}
|
||||
)
|
||||
data.append(serialize_collected_row(row[:11]))
|
||||
|
||||
return {
|
||||
"total": total,
|
||||
@@ -105,16 +203,19 @@ async def list_collected_data(
|
||||
|
||||
@router.get("/summary")
|
||||
async def get_data_summary(
|
||||
mode: str = Query("current", description="查询模式: current/history"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""获取数据汇总统计"""
|
||||
where_sql = "WHERE COALESCE(is_current, TRUE) = TRUE" if mode != "history" else ""
|
||||
|
||||
# By source and data_type
|
||||
result = await db.execute(
|
||||
text("""
|
||||
SELECT source, data_type, COUNT(*) as count
|
||||
FROM collected_data
|
||||
""" + where_sql + """
|
||||
GROUP BY source, data_type
|
||||
ORDER BY source, data_type
|
||||
""")
|
||||
@@ -138,6 +239,7 @@ async def get_data_summary(
|
||||
text("""
|
||||
SELECT source, COUNT(*) as count
|
||||
FROM collected_data
|
||||
""" + where_sql + """
|
||||
GROUP BY source
|
||||
ORDER BY count DESC
|
||||
""")
|
||||
@@ -153,6 +255,7 @@ async def get_data_summary(
|
||||
|
||||
@router.get("/sources")
|
||||
async def get_data_sources(
|
||||
mode: str = Query("current", description="查询模式: current/history"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
@@ -160,7 +263,9 @@ async def get_data_sources(
|
||||
|
||||
result = await db.execute(
|
||||
text("""
|
||||
SELECT DISTINCT source FROM collected_data ORDER BY source
|
||||
SELECT DISTINCT source FROM collected_data
|
||||
""" + ("WHERE COALESCE(is_current, TRUE) = TRUE " if mode != "history" else "") + """
|
||||
ORDER BY source
|
||||
""")
|
||||
)
|
||||
rows = result.fetchall()
|
||||
@@ -172,6 +277,7 @@ async def get_data_sources(
|
||||
|
||||
@router.get("/types")
|
||||
async def get_data_types(
|
||||
mode: str = Query("current", description="查询模式: current/history"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
@@ -179,7 +285,9 @@ async def get_data_types(
|
||||
|
||||
result = await db.execute(
|
||||
text("""
|
||||
SELECT DISTINCT data_type FROM collected_data ORDER BY data_type
|
||||
SELECT DISTINCT data_type FROM collected_data
|
||||
""" + ("WHERE COALESCE(is_current, TRUE) = TRUE " if mode != "history" else "") + """
|
||||
ORDER BY data_type
|
||||
""")
|
||||
)
|
||||
rows = result.fetchall()
|
||||
@@ -196,17 +304,8 @@ async def get_countries(
|
||||
):
|
||||
"""获取所有国家列表"""
|
||||
|
||||
result = await db.execute(
|
||||
text("""
|
||||
SELECT DISTINCT country FROM collected_data
|
||||
WHERE country IS NOT NULL AND country != ''
|
||||
ORDER BY country
|
||||
""")
|
||||
)
|
||||
rows = result.fetchall()
|
||||
|
||||
return {
|
||||
"countries": [row[0] for row in rows],
|
||||
"countries": COUNTRY_OPTIONS,
|
||||
}
|
||||
|
||||
|
||||
@@ -221,7 +320,6 @@ async def get_collected_data(
|
||||
result = await db.execute(
|
||||
text("""
|
||||
SELECT id, source, source_id, data_type, name, title, description,
|
||||
country, city, latitude, longitude, value, unit,
|
||||
metadata, collected_at, reference_date, is_valid
|
||||
FROM collected_data
|
||||
WHERE id = :id
|
||||
@@ -236,25 +334,7 @@ async def get_collected_data(
|
||||
detail="数据不存在",
|
||||
)
|
||||
|
||||
return {
|
||||
"id": row[0],
|
||||
"source": row[1],
|
||||
"source_id": row[2],
|
||||
"data_type": row[3],
|
||||
"name": row[4],
|
||||
"title": row[5],
|
||||
"description": row[6],
|
||||
"country": row[7],
|
||||
"city": row[8],
|
||||
"latitude": row[9],
|
||||
"longitude": row[10],
|
||||
"value": row[11],
|
||||
"unit": row[12],
|
||||
"metadata": row[13],
|
||||
"collected_at": row[14].isoformat() if row[14] else None,
|
||||
"reference_date": row[15].isoformat() if row[15] else None,
|
||||
"is_valid": row[16],
|
||||
}
|
||||
return serialize_collected_row(row)
|
||||
|
||||
|
||||
def build_where_clause(
|
||||
@@ -263,19 +343,21 @@ def build_where_clause(
|
||||
"""Build WHERE clause and params for queries"""
|
||||
conditions = []
|
||||
params = {}
|
||||
source_values = parse_multi_values(source)
|
||||
data_type_values = parse_multi_values(data_type)
|
||||
|
||||
if source:
|
||||
conditions.append("source = :source")
|
||||
params["source"] = source
|
||||
if data_type:
|
||||
conditions.append("data_type = :data_type")
|
||||
params["data_type"] = data_type
|
||||
if country:
|
||||
conditions.append("country = :country")
|
||||
params["country"] = country
|
||||
if search:
|
||||
conditions.append("(name ILIKE :search OR title ILIKE :search)")
|
||||
params["search"] = f"%{search}%"
|
||||
if source_values:
|
||||
conditions.append(build_in_condition("source", source_values, "source", params))
|
||||
if data_type_values:
|
||||
conditions.append(build_in_condition("data_type", data_type_values, "data_type", params))
|
||||
normalized_country = normalize_country(country) if country else None
|
||||
|
||||
if normalized_country:
|
||||
conditions.append(f"{COUNTRY_SQL} = :country")
|
||||
params["country"] = normalized_country
|
||||
search_condition = build_search_condition(search, params)
|
||||
if search_condition:
|
||||
conditions.append(search_condition)
|
||||
|
||||
where_sql = " AND ".join(conditions) if conditions else "1=1"
|
||||
return where_sql, params
|
||||
@@ -283,6 +365,7 @@ def build_where_clause(
|
||||
|
||||
@router.get("/export/json")
|
||||
async def export_json(
|
||||
mode: str = Query("current", description="查询模式: current/history"),
|
||||
source: Optional[str] = Query(None, description="数据源过滤"),
|
||||
data_type: Optional[str] = Query(None, description="数据类型过滤"),
|
||||
country: Optional[str] = Query(None, description="国家过滤"),
|
||||
@@ -294,11 +377,12 @@ async def export_json(
|
||||
"""导出数据为 JSON 格式"""
|
||||
|
||||
where_sql, params = build_where_clause(source, data_type, country, search)
|
||||
if mode != "history":
|
||||
where_sql = f"({where_sql}) AND COALESCE(is_current, TRUE) = TRUE"
|
||||
params["limit"] = limit
|
||||
|
||||
query = text(f"""
|
||||
SELECT id, source, source_id, data_type, name, title, description,
|
||||
country, city, latitude, longitude, value, unit,
|
||||
metadata, collected_at, reference_date, is_valid
|
||||
FROM collected_data
|
||||
WHERE {where_sql}
|
||||
@@ -311,27 +395,7 @@ async def export_json(
|
||||
|
||||
data = []
|
||||
for row in rows:
|
||||
data.append(
|
||||
{
|
||||
"id": row[0],
|
||||
"source": row[1],
|
||||
"source_id": row[2],
|
||||
"data_type": row[3],
|
||||
"name": row[4],
|
||||
"title": row[5],
|
||||
"description": row[6],
|
||||
"country": row[7],
|
||||
"city": row[8],
|
||||
"latitude": row[9],
|
||||
"longitude": row[10],
|
||||
"value": row[11],
|
||||
"unit": row[12],
|
||||
"metadata": row[13],
|
||||
"collected_at": row[14].isoformat() if row[14] else None,
|
||||
"reference_date": row[15].isoformat() if row[15] else None,
|
||||
"is_valid": row[16],
|
||||
}
|
||||
)
|
||||
data.append(serialize_collected_row(row))
|
||||
|
||||
json_str = json.dumps({"data": data, "total": len(data)}, ensure_ascii=False, indent=2)
|
||||
|
||||
@@ -346,6 +410,7 @@ async def export_json(
|
||||
|
||||
@router.get("/export/csv")
|
||||
async def export_csv(
|
||||
mode: str = Query("current", description="查询模式: current/history"),
|
||||
source: Optional[str] = Query(None, description="数据源过滤"),
|
||||
data_type: Optional[str] = Query(None, description="数据类型过滤"),
|
||||
country: Optional[str] = Query(None, description="国家过滤"),
|
||||
@@ -357,11 +422,12 @@ async def export_csv(
|
||||
"""导出数据为 CSV 格式"""
|
||||
|
||||
where_sql, params = build_where_clause(source, data_type, country, search)
|
||||
if mode != "history":
|
||||
where_sql = f"({where_sql}) AND COALESCE(is_current, TRUE) = TRUE"
|
||||
params["limit"] = limit
|
||||
|
||||
query = text(f"""
|
||||
SELECT id, source, source_id, data_type, name, title, description,
|
||||
country, city, latitude, longitude, value, unit,
|
||||
metadata, collected_at, reference_date, is_valid
|
||||
FROM collected_data
|
||||
WHERE {where_sql}
|
||||
@@ -409,16 +475,16 @@ async def export_csv(
|
||||
row[4],
|
||||
row[5],
|
||||
row[6],
|
||||
row[7],
|
||||
row[8],
|
||||
row[9],
|
||||
get_metadata_field(row[7], "country"),
|
||||
get_metadata_field(row[7], "city"),
|
||||
get_metadata_field(row[7], "latitude"),
|
||||
get_metadata_field(row[7], "longitude"),
|
||||
get_metadata_field(row[7], "value"),
|
||||
get_metadata_field(row[7], "unit"),
|
||||
json.dumps(row[7]) if row[7] else "",
|
||||
row[8].isoformat() if row[8] else "",
|
||||
row[9].isoformat() if row[9] else "",
|
||||
row[10],
|
||||
row[11],
|
||||
row[12],
|
||||
json.dumps(row[13]) if row[13] else "",
|
||||
row[14].isoformat() if row[14] else "",
|
||||
row[15].isoformat() if row[15] else "",
|
||||
row[16],
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@@ -5,12 +5,13 @@ from sqlalchemy import func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.security import get_current_user
|
||||
from app.core.data_sources import get_data_sources_config
|
||||
from app.db.session import get_db
|
||||
from app.models.collected_data import CollectedData
|
||||
from app.models.datasource import DataSource
|
||||
from app.models.task import CollectionTask
|
||||
from app.models.user import User
|
||||
from app.services.scheduler import run_collector_now, sync_datasource_job
|
||||
from app.services.scheduler import get_latest_task_id_for_datasource, run_collector_now, sync_datasource_job
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@@ -83,9 +84,11 @@ async def list_datasources(
|
||||
datasources = result.scalars().all()
|
||||
|
||||
collector_list = []
|
||||
config = get_data_sources_config()
|
||||
for datasource in datasources:
|
||||
running_task = await get_running_task(db, datasource.id)
|
||||
last_task = await get_last_completed_task(db, datasource.id)
|
||||
endpoint = await config.get_url(datasource.source, db)
|
||||
data_count_result = await db.execute(
|
||||
select(func.count(CollectedData.id)).where(CollectedData.source == datasource.source)
|
||||
)
|
||||
@@ -105,10 +108,12 @@ async def list_datasources(
|
||||
"frequency_minutes": datasource.frequency_minutes,
|
||||
"is_active": datasource.is_active,
|
||||
"collector_class": datasource.collector_class,
|
||||
"endpoint": endpoint,
|
||||
"last_run": last_run,
|
||||
"is_running": running_task is not None,
|
||||
"task_id": running_task.id if running_task else None,
|
||||
"progress": running_task.progress if running_task else None,
|
||||
"phase": running_task.phase if running_task else None,
|
||||
"records_processed": running_task.records_processed if running_task else None,
|
||||
"total_records": running_task.total_records if running_task else None,
|
||||
}
|
||||
@@ -127,6 +132,9 @@ async def get_datasource(
|
||||
if not datasource:
|
||||
raise HTTPException(status_code=404, detail="Data source not found")
|
||||
|
||||
config = get_data_sources_config()
|
||||
endpoint = await config.get_url(datasource.source, db)
|
||||
|
||||
return {
|
||||
"id": datasource.id,
|
||||
"name": datasource.name,
|
||||
@@ -136,6 +144,7 @@ async def get_datasource(
|
||||
"frequency_minutes": datasource.frequency_minutes,
|
||||
"collector_class": datasource.collector_class,
|
||||
"source": datasource.source,
|
||||
"endpoint": endpoint,
|
||||
"is_active": datasource.is_active,
|
||||
}
|
||||
|
||||
@@ -212,9 +221,16 @@ async def trigger_datasource(
|
||||
if not success:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to trigger collector '{datasource.source}'")
|
||||
|
||||
task_id = None
|
||||
for _ in range(10):
|
||||
task_id = await get_latest_task_id_for_datasource(datasource.id)
|
||||
if task_id is not None:
|
||||
break
|
||||
|
||||
return {
|
||||
"status": "triggered",
|
||||
"source_id": datasource.id,
|
||||
"task_id": task_id,
|
||||
"collector_name": datasource.source,
|
||||
"message": f"Collector '{datasource.source}' has been triggered",
|
||||
}
|
||||
@@ -252,21 +268,29 @@ async def clear_datasource_data(
|
||||
@router.get("/{source_id}/task-status")
|
||||
async def get_task_status(
|
||||
source_id: str,
|
||||
task_id: Optional[int] = None,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
datasource = await get_datasource_record(db, source_id)
|
||||
if not datasource:
|
||||
raise HTTPException(status_code=404, detail="Data source not found")
|
||||
|
||||
running_task = await get_running_task(db, datasource.id)
|
||||
if not running_task:
|
||||
return {"is_running": False, "task_id": None, "progress": None}
|
||||
if task_id is not None:
|
||||
task = await db.get(CollectionTask, task_id)
|
||||
if not task or task.datasource_id != datasource.id:
|
||||
raise HTTPException(status_code=404, detail="Task not found")
|
||||
else:
|
||||
task = await get_running_task(db, datasource.id)
|
||||
|
||||
if not task:
|
||||
return {"is_running": False, "task_id": None, "progress": None, "phase": None, "status": "idle"}
|
||||
|
||||
return {
|
||||
"is_running": True,
|
||||
"task_id": running_task.id,
|
||||
"progress": running_task.progress,
|
||||
"records_processed": running_task.records_processed,
|
||||
"total_records": running_task.total_records,
|
||||
"status": running_task.status,
|
||||
}
|
||||
"is_running": task.status == "running",
|
||||
"task_id": task.id,
|
||||
"progress": task.progress,
|
||||
"phase": task.phase,
|
||||
"records_processed": task.records_processed,
|
||||
"total_records": task.total_records,
|
||||
"status": task.status,
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, func
|
||||
from typing import List, Dict, Any, Optional
|
||||
|
||||
from app.core.collected_data_fields import get_record_field
|
||||
from app.db.session import get_db
|
||||
from app.models.collected_data import CollectedData
|
||||
from app.services.cable_graph import build_graph_from_data, CableGraph
|
||||
@@ -83,9 +84,9 @@ def convert_cable_to_geojson(records: List[CollectedData]) -> Dict[str, Any]:
|
||||
"rfs": metadata.get("rfs"),
|
||||
"RFS": metadata.get("rfs"),
|
||||
"status": metadata.get("status", "active"),
|
||||
"length": record.value,
|
||||
"length_km": record.value,
|
||||
"SHAPE__Length": record.value,
|
||||
"length": get_record_field(record, "value"),
|
||||
"length_km": get_record_field(record, "value"),
|
||||
"SHAPE__Length": get_record_field(record, "value"),
|
||||
"url": metadata.get("url"),
|
||||
"color": metadata.get("color"),
|
||||
"year": metadata.get("year"),
|
||||
@@ -101,8 +102,10 @@ def convert_landing_point_to_geojson(records: List[CollectedData], city_to_cable
|
||||
|
||||
for record in records:
|
||||
try:
|
||||
lat = float(record.latitude) if record.latitude else None
|
||||
lon = float(record.longitude) if record.longitude else None
|
||||
latitude = get_record_field(record, "latitude")
|
||||
longitude = get_record_field(record, "longitude")
|
||||
lat = float(latitude) if latitude else None
|
||||
lon = float(longitude) if longitude else None
|
||||
except (ValueError, TypeError):
|
||||
continue
|
||||
|
||||
@@ -116,8 +119,8 @@ def convert_landing_point_to_geojson(records: List[CollectedData], city_to_cable
|
||||
"id": record.id,
|
||||
"source_id": record.source_id,
|
||||
"name": record.name,
|
||||
"country": record.country,
|
||||
"city": record.city,
|
||||
"country": get_record_field(record, "country"),
|
||||
"city": get_record_field(record, "city"),
|
||||
"is_tbd": metadata.get("is_tbd", False),
|
||||
}
|
||||
|
||||
@@ -185,9 +188,11 @@ def convert_supercomputer_to_geojson(records: List[CollectedData]) -> Dict[str,
|
||||
|
||||
for record in records:
|
||||
try:
|
||||
lat = float(record.latitude) if record.latitude and record.latitude != "0.0" else None
|
||||
latitude = get_record_field(record, "latitude")
|
||||
longitude = get_record_field(record, "longitude")
|
||||
lat = float(latitude) if latitude and latitude != "0.0" else None
|
||||
lon = (
|
||||
float(record.longitude) if record.longitude and record.longitude != "0.0" else None
|
||||
float(longitude) if longitude and longitude != "0.0" else None
|
||||
)
|
||||
except (ValueError, TypeError):
|
||||
lat, lon = None, None
|
||||
@@ -203,12 +208,12 @@ def convert_supercomputer_to_geojson(records: List[CollectedData]) -> Dict[str,
|
||||
"id": record.id,
|
||||
"name": record.name,
|
||||
"rank": metadata.get("rank"),
|
||||
"r_max": record.value,
|
||||
"r_peak": metadata.get("r_peak"),
|
||||
"cores": metadata.get("cores"),
|
||||
"power": metadata.get("power"),
|
||||
"country": record.country,
|
||||
"city": record.city,
|
||||
"r_max": get_record_field(record, "rmax"),
|
||||
"r_peak": get_record_field(record, "rpeak"),
|
||||
"cores": get_record_field(record, "cores"),
|
||||
"power": get_record_field(record, "power"),
|
||||
"country": get_record_field(record, "country"),
|
||||
"city": get_record_field(record, "city"),
|
||||
"data_type": "supercomputer",
|
||||
},
|
||||
}
|
||||
@@ -223,8 +228,10 @@ def convert_gpu_cluster_to_geojson(records: List[CollectedData]) -> Dict[str, An
|
||||
|
||||
for record in records:
|
||||
try:
|
||||
lat = float(record.latitude) if record.latitude else None
|
||||
lon = float(record.longitude) if record.longitude else None
|
||||
latitude = get_record_field(record, "latitude")
|
||||
longitude = get_record_field(record, "longitude")
|
||||
lat = float(latitude) if latitude else None
|
||||
lon = float(longitude) if longitude else None
|
||||
except (ValueError, TypeError):
|
||||
lat, lon = None, None
|
||||
|
||||
@@ -238,8 +245,8 @@ def convert_gpu_cluster_to_geojson(records: List[CollectedData]) -> Dict[str, An
|
||||
"properties": {
|
||||
"id": record.id,
|
||||
"name": record.name,
|
||||
"country": record.country,
|
||||
"city": record.city,
|
||||
"country": get_record_field(record, "country"),
|
||||
"city": get_record_field(record, "city"),
|
||||
"metadata": metadata,
|
||||
"data_type": "gpu_cluster",
|
||||
},
|
||||
|
||||
62
backend/app/core/collected_data_fields.py
Normal file
62
backend/app/core/collected_data_fields.py
Normal file
@@ -0,0 +1,62 @@
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
|
||||
FIELD_ALIASES = {
|
||||
"country": ("country",),
|
||||
"city": ("city",),
|
||||
"latitude": ("latitude",),
|
||||
"longitude": ("longitude",),
|
||||
"value": ("value",),
|
||||
"unit": ("unit",),
|
||||
"cores": ("cores",),
|
||||
"rmax": ("rmax", "r_max"),
|
||||
"rpeak": ("rpeak", "r_peak"),
|
||||
"power": ("power",),
|
||||
}
|
||||
|
||||
|
||||
def get_metadata_field(metadata: Optional[Dict[str, Any]], field: str, fallback: Any = None) -> Any:
|
||||
if isinstance(metadata, dict):
|
||||
for key in FIELD_ALIASES.get(field, (field,)):
|
||||
value = metadata.get(key)
|
||||
if value not in (None, ""):
|
||||
return value
|
||||
return fallback
|
||||
|
||||
|
||||
def build_dynamic_metadata(
|
||||
metadata: Optional[Dict[str, Any]],
|
||||
*,
|
||||
country: Any = None,
|
||||
city: Any = None,
|
||||
latitude: Any = None,
|
||||
longitude: Any = None,
|
||||
value: Any = None,
|
||||
unit: Any = None,
|
||||
) -> Dict[str, Any]:
|
||||
merged = dict(metadata) if isinstance(metadata, dict) else {}
|
||||
|
||||
fallbacks = {
|
||||
"country": country,
|
||||
"city": city,
|
||||
"latitude": latitude,
|
||||
"longitude": longitude,
|
||||
"value": value,
|
||||
"unit": unit,
|
||||
}
|
||||
|
||||
for field, fallback in fallbacks.items():
|
||||
if fallback not in (None, "") and get_metadata_field(merged, field) in (None, ""):
|
||||
merged[field] = fallback
|
||||
|
||||
return merged
|
||||
|
||||
|
||||
def get_record_field(record: Any, field: str) -> Any:
|
||||
metadata = getattr(record, "extra_data", None) or {}
|
||||
fallback_attr = field
|
||||
if field in {"cores", "rmax", "rpeak", "power"}:
|
||||
fallback = None
|
||||
else:
|
||||
fallback = getattr(record, fallback_attr, None)
|
||||
return get_metadata_field(metadata, field, fallback=fallback)
|
||||
280
backend/app/core/countries.py
Normal file
280
backend/app/core/countries.py
Normal file
@@ -0,0 +1,280 @@
|
||||
import re
|
||||
from typing import Any, Optional
|
||||
|
||||
|
||||
# (canonical Chinese display name, [English names / ISO 3166 alpha-2 and
# alpha-3 codes / common alternative spellings]) — lookup aliases are matched
# case-insensitively by the derived maps below.
COUNTRY_ENTRIES = [
    ("阿富汗", ["Afghanistan", "AF", "AFG"]),
    ("阿尔巴尼亚", ["Albania", "AL", "ALB"]),
    ("阿尔及利亚", ["Algeria", "DZ", "DZA"]),
    ("安道尔", ["Andorra", "AD", "AND"]),
    ("安哥拉", ["Angola", "AO", "AGO"]),
    ("安提瓜和巴布达", ["Antigua and Barbuda", "AG", "ATG"]),
    ("阿根廷", ["Argentina", "AR", "ARG"]),
    ("亚美尼亚", ["Armenia", "AM", "ARM"]),
    ("澳大利亚", ["Australia", "AU", "AUS"]),
    ("奥地利", ["Austria", "AT", "AUT"]),
    ("阿塞拜疆", ["Azerbaijan", "AZ", "AZE"]),
    ("巴哈马", ["Bahamas", "BS", "BHS"]),
    ("巴林", ["Bahrain", "BH", "BHR"]),
    ("孟加拉国", ["Bangladesh", "BD", "BGD"]),
    ("巴巴多斯", ["Barbados", "BB", "BRB"]),
    ("白俄罗斯", ["Belarus", "BY", "BLR"]),
    ("比利时", ["Belgium", "BE", "BEL"]),
    ("伯利兹", ["Belize", "BZ", "BLZ"]),
    ("贝宁", ["Benin", "BJ", "BEN"]),
    ("不丹", ["Bhutan", "BT", "BTN"]),
    ("玻利维亚", ["Bolivia", "BO", "BOL", "Bolivia (Plurinational State of)"]),
    ("波斯尼亚和黑塞哥维那", ["Bosnia and Herzegovina", "BA", "BIH"]),
    ("博茨瓦纳", ["Botswana", "BW", "BWA"]),
    ("巴西", ["Brazil", "BR", "BRA"]),
    ("文莱", ["Brunei", "BN", "BRN", "Brunei Darussalam"]),
    ("保加利亚", ["Bulgaria", "BG", "BGR"]),
    ("布基纳法索", ["Burkina Faso", "BF", "BFA"]),
    ("布隆迪", ["Burundi", "BI", "BDI"]),
    ("柬埔寨", ["Cambodia", "KH", "KHM"]),
    ("喀麦隆", ["Cameroon", "CM", "CMR"]),
    ("加拿大", ["Canada", "CA", "CAN"]),
    ("佛得角", ["Cape Verde", "CV", "CPV", "Cabo Verde"]),
    ("中非", ["Central African Republic", "CF", "CAF"]),
    ("乍得", ["Chad", "TD", "TCD"]),
    ("智利", ["Chile", "CL", "CHL"]),
    ("中国", ["China", "CN", "CHN", "Mainland China", "PRC", "People's Republic of China"]),
    ("中国(香港)", ["Hong Kong", "HK", "HKG", "Hong Kong SAR", "China Hong Kong", "Hong Kong, China"]),
    ("中国(澳门)", ["Macao", "Macau", "MO", "MAC", "Macao SAR", "China Macao", "Macau, China"]),
    ("中国(台湾)", ["Taiwan", "TW", "TWN", "Chinese Taipei", "Taiwan, China"]),
    ("哥伦比亚", ["Colombia", "CO", "COL"]),
    ("科摩罗", ["Comoros", "KM", "COM"]),
    ("刚果(布)", ["Republic of the Congo", "Congo", "Congo-Brazzaville", "CG", "COG"]),
    ("刚果(金)", ["Democratic Republic of the Congo", "DR Congo", "Congo-Kinshasa", "CD", "COD"]),
    ("哥斯达黎加", ["Costa Rica", "CR", "CRI"]),
    ("科特迪瓦", ["Cote d'Ivoire", "Côte d'Ivoire", "Ivory Coast", "CI", "CIV"]),
    ("克罗地亚", ["Croatia", "HR", "HRV"]),
    ("古巴", ["Cuba", "CU", "CUB"]),
    ("塞浦路斯", ["Cyprus", "CY", "CYP"]),
    ("捷克", ["Czech Republic", "Czechia", "CZ", "CZE"]),
    ("丹麦", ["Denmark", "DK", "DNK"]),
    ("吉布提", ["Djibouti", "DJ", "DJI"]),
    ("多米尼克", ["Dominica", "DM", "DMA"]),
    ("多米尼加", ["Dominican Republic", "DO", "DOM"]),
    ("厄瓜多尔", ["Ecuador", "EC", "ECU"]),
    ("埃及", ["Egypt", "EG", "EGY"]),
    ("萨尔瓦多", ["El Salvador", "SV", "SLV"]),
    ("赤道几内亚", ["Equatorial Guinea", "GQ", "GNQ"]),
    ("厄立特里亚", ["Eritrea", "ER", "ERI"]),
    ("爱沙尼亚", ["Estonia", "EE", "EST"]),
    ("埃斯瓦蒂尼", ["Eswatini", "SZ", "SWZ", "Swaziland"]),
    ("埃塞俄比亚", ["Ethiopia", "ET", "ETH"]),
    ("斐济", ["Fiji", "FJ", "FJI"]),
    ("芬兰", ["Finland", "FI", "FIN"]),
    ("法国", ["France", "FR", "FRA"]),
    ("加蓬", ["Gabon", "GA", "GAB"]),
    ("冈比亚", ["Gambia", "GM", "GMB"]),
    ("格鲁吉亚", ["Georgia", "GE", "GEO"]),
    ("德国", ["Germany", "DE", "DEU"]),
    ("加纳", ["Ghana", "GH", "GHA"]),
    ("希腊", ["Greece", "GR", "GRC"]),
    ("格林纳达", ["Grenada", "GD", "GRD"]),
    ("危地马拉", ["Guatemala", "GT", "GTM"]),
    ("几内亚", ["Guinea", "GN", "GIN"]),
    ("几内亚比绍", ["Guinea-Bissau", "GW", "GNB"]),
    ("圭亚那", ["Guyana", "GY", "GUY"]),
    ("海地", ["Haiti", "HT", "HTI"]),
    ("洪都拉斯", ["Honduras", "HN", "HND"]),
    ("匈牙利", ["Hungary", "HU", "HUN"]),
    ("冰岛", ["Iceland", "IS", "ISL"]),
    ("印度", ["India", "IN", "IND"]),
    ("印度尼西亚", ["Indonesia", "ID", "IDN"]),
    ("伊朗", ["Iran", "IR", "IRN", "Iran (Islamic Republic of)"]),
    ("伊拉克", ["Iraq", "IQ", "IRQ"]),
    ("爱尔兰", ["Ireland", "IE", "IRL"]),
    ("以色列", ["Israel", "IL", "ISR"]),
    ("意大利", ["Italy", "IT", "ITA"]),
    ("牙买加", ["Jamaica", "JM", "JAM"]),
    ("日本", ["Japan", "JP", "JPN"]),
    ("约旦", ["Jordan", "JO", "JOR"]),
    ("哈萨克斯坦", ["Kazakhstan", "KZ", "KAZ"]),
    ("肯尼亚", ["Kenya", "KE", "KEN"]),
    ("基里巴斯", ["Kiribati", "KI", "KIR"]),
    ("朝鲜", ["North Korea", "Korea, DPRK", "Democratic People's Republic of Korea", "KP", "PRK"]),
    ("韩国", ["South Korea", "Republic of Korea", "Korea", "KR", "KOR"]),
    ("科威特", ["Kuwait", "KW", "KWT"]),
    ("吉尔吉斯斯坦", ["Kyrgyzstan", "KG", "KGZ"]),
    ("老挝", ["Laos", "Lao PDR", "Lao People's Democratic Republic", "LA", "LAO"]),
    ("拉脱维亚", ["Latvia", "LV", "LVA"]),
    ("黎巴嫩", ["Lebanon", "LB", "LBN"]),
    ("莱索托", ["Lesotho", "LS", "LSO"]),
    ("利比里亚", ["Liberia", "LR", "LBR"]),
    ("利比亚", ["Libya", "LY", "LBY"]),
    ("列支敦士登", ["Liechtenstein", "LI", "LIE"]),
    ("立陶宛", ["Lithuania", "LT", "LTU"]),
    ("卢森堡", ["Luxembourg", "LU", "LUX"]),
    ("马达加斯加", ["Madagascar", "MG", "MDG"]),
    ("马拉维", ["Malawi", "MW", "MWI"]),
    ("马来西亚", ["Malaysia", "MY", "MYS"]),
    ("马尔代夫", ["Maldives", "MV", "MDV"]),
    ("马里", ["Mali", "ML", "MLI"]),
    ("马耳他", ["Malta", "MT", "MLT"]),
    ("马绍尔群岛", ["Marshall Islands", "MH", "MHL"]),
    ("毛里塔尼亚", ["Mauritania", "MR", "MRT"]),
    ("毛里求斯", ["Mauritius", "MU", "MUS"]),
    ("墨西哥", ["Mexico", "MX", "MEX"]),
    ("密克罗尼西亚", ["Micronesia", "FM", "FSM", "Federated States of Micronesia"]),
    ("摩尔多瓦", ["Moldova", "MD", "MDA", "Republic of Moldova"]),
    ("摩纳哥", ["Monaco", "MC", "MCO"]),
    ("蒙古", ["Mongolia", "MN", "MNG"]),
    ("黑山", ["Montenegro", "ME", "MNE"]),
    ("摩洛哥", ["Morocco", "MA", "MAR"]),
    ("莫桑比克", ["Mozambique", "MZ", "MOZ"]),
    ("缅甸", ["Myanmar", "MM", "MMR", "Burma"]),
    ("纳米比亚", ["Namibia", "NA", "NAM"]),
    ("瑙鲁", ["Nauru", "NR", "NRU"]),
    ("尼泊尔", ["Nepal", "NP", "NPL"]),
    ("荷兰", ["Netherlands", "NL", "NLD"]),
    ("新西兰", ["New Zealand", "NZ", "NZL"]),
    ("尼加拉瓜", ["Nicaragua", "NI", "NIC"]),
    ("尼日尔", ["Niger", "NE", "NER"]),
    ("尼日利亚", ["Nigeria", "NG", "NGA"]),
    ("北马其顿", ["North Macedonia", "MK", "MKD", "Macedonia"]),
    ("挪威", ["Norway", "NO", "NOR"]),
    ("阿曼", ["Oman", "OM", "OMN"]),
    ("巴基斯坦", ["Pakistan", "PK", "PAK"]),
    ("帕劳", ["Palau", "PW", "PLW"]),
    ("巴勒斯坦", ["Palestine", "PS", "PSE", "State of Palestine"]),
    ("巴拿马", ["Panama", "PA", "PAN"]),
    ("巴布亚新几内亚", ["Papua New Guinea", "PG", "PNG"]),
    ("巴拉圭", ["Paraguay", "PY", "PRY"]),
    ("秘鲁", ["Peru", "PE", "PER"]),
    ("菲律宾", ["Philippines", "PH", "PHL"]),
    ("波兰", ["Poland", "PL", "POL"]),
    ("葡萄牙", ["Portugal", "PT", "PRT"]),
    ("卡塔尔", ["Qatar", "QA", "QAT"]),
    ("罗马尼亚", ["Romania", "RO", "ROU"]),
    ("俄罗斯", ["Russia", "Russian Federation", "RU", "RUS"]),
    ("卢旺达", ["Rwanda", "RW", "RWA"]),
    ("圣基茨和尼维斯", ["Saint Kitts and Nevis", "KN", "KNA"]),
    ("圣卢西亚", ["Saint Lucia", "LC", "LCA"]),
    ("圣文森特和格林纳丁斯", ["Saint Vincent and the Grenadines", "VC", "VCT"]),
    ("萨摩亚", ["Samoa", "WS", "WSM"]),
    ("圣马力诺", ["San Marino", "SM", "SMR"]),
    ("圣多美和普林西比", ["Sao Tome and Principe", "ST", "STP", "São Tomé and Príncipe"]),
    ("沙特阿拉伯", ["Saudi Arabia", "SA", "SAU"]),
    ("塞内加尔", ["Senegal", "SN", "SEN"]),
    ("塞尔维亚", ["Serbia", "RS", "SRB", "Kosovo", "XK", "XKS", "Republic of Kosovo"]),
    ("塞舌尔", ["Seychelles", "SC", "SYC"]),
    ("塞拉利昂", ["Sierra Leone", "SL", "SLE"]),
    ("新加坡", ["Singapore", "SG", "SGP"]),
    ("斯洛伐克", ["Slovakia", "SK", "SVK"]),
    ("斯洛文尼亚", ["Slovenia", "SI", "SVN"]),
    ("所罗门群岛", ["Solomon Islands", "SB", "SLB"]),
    ("索马里", ["Somalia", "SO", "SOM"]),
    ("南非", ["South Africa", "ZA", "ZAF"]),
    ("南苏丹", ["South Sudan", "SS", "SSD"]),
    ("西班牙", ["Spain", "ES", "ESP"]),
    ("斯里兰卡", ["Sri Lanka", "LK", "LKA"]),
    ("苏丹", ["Sudan", "SD", "SDN"]),
    ("苏里南", ["Suriname", "SR", "SUR"]),
    ("瑞典", ["Sweden", "SE", "SWE"]),
    ("瑞士", ["Switzerland", "CH", "CHE"]),
    ("叙利亚", ["Syria", "SY", "SYR", "Syrian Arab Republic"]),
    ("塔吉克斯坦", ["Tajikistan", "TJ", "TJK"]),
    ("坦桑尼亚", ["Tanzania", "TZ", "TZA", "United Republic of Tanzania"]),
    ("泰国", ["Thailand", "TH", "THA"]),
    ("东帝汶", ["Timor-Leste", "East Timor", "TL", "TLS"]),
    ("多哥", ["Togo", "TG", "TGO"]),
    ("汤加", ["Tonga", "TO", "TON"]),
    ("特立尼达和多巴哥", ["Trinidad and Tobago", "TT", "TTO"]),
    ("突尼斯", ["Tunisia", "TN", "TUN"]),
    ("土耳其", ["Turkey", "TR", "TUR", "Türkiye"]),
    ("土库曼斯坦", ["Turkmenistan", "TM", "TKM"]),
    ("图瓦卢", ["Tuvalu", "TV", "TUV"]),
    ("乌干达", ["Uganda", "UG", "UGA"]),
    ("乌克兰", ["Ukraine", "UA", "UKR"]),
    ("阿联酋", ["United Arab Emirates", "AE", "ARE", "UAE"]),
    ("英国", ["United Kingdom", "UK", "GB", "GBR", "Great Britain", "Britain", "England"]),
    ("美国", ["United States", "United States of America", "US", "USA", "U.S.", "U.S.A."]),
    ("乌拉圭", ["Uruguay", "UY", "URY"]),
    ("乌兹别克斯坦", ["Uzbekistan", "UZ", "UZB"]),
    ("瓦努阿图", ["Vanuatu", "VU", "VUT"]),
    ("梵蒂冈", ["Vatican City", "Holy See", "VA", "VAT"]),
    ("委内瑞拉", ["Venezuela", "VE", "VEN", "Venezuela (Bolivarian Republic of)"]),
    ("越南", ["Vietnam", "Viet Nam", "VN", "VNM"]),
    ("也门", ["Yemen", "YE", "YEM"]),
    ("赞比亚", ["Zambia", "ZM", "ZMB"]),
    ("津巴布韦", ["Zimbabwe", "ZW", "ZWE"]),
]
|
||||
|
||||
|
||||
# Canonical display names, in entry order, plus a set for O(1) membership tests.
COUNTRY_OPTIONS = [entry[0] for entry in COUNTRY_ENTRIES]
CANONICAL_COUNTRY_SET = set(COUNTRY_OPTIONS)

# Placeholder strings that must never be treated as a country name
# (compared case-insensitively after whitespace normalization).
INVALID_COUNTRY_VALUES = {
    "",
    "-",
    "--",
    "unknown",
    "n/a",
    "na",
    "none",
    "null",
    "global",
    "world",
    "worldwide",
    "xx",
}

# Strings consisting purely of digits/whitespace/numeric punctuation
# (e.g. "1,234", "-5", "12.3%") are rejected as country names.
NUMERIC_LIKE_PATTERN = re.compile(r"^[\d\s,._%+\-]+$")

# casefolded alias -> canonical name; canonical name -> [canonical, *aliases].
COUNTRY_ALIAS_MAP = {}
COUNTRY_VARIANTS_MAP = {}
for canonical, aliases in COUNTRY_ENTRIES:
    COUNTRY_ALIAS_MAP[canonical.casefold()] = canonical
    variants = [canonical, *aliases]
    COUNTRY_VARIANTS_MAP[canonical] = variants
    for alias in aliases:
        COUNTRY_ALIAS_MAP[alias.casefold()] = canonical
|
||||
|
||||
def normalize_country(value: Any) -> Optional[str]:
    """Map an arbitrary country string to its canonical display name.

    Returns None for non-strings, blank or placeholder values, purely
    numeric-looking strings, and anything matching no known alias.
    """
    if value is None or not isinstance(value, str):
        return None

    normalized = re.sub(r"\s+", " ", value.strip())
    # NOTE(review): both replace() arguments render as the same ASCII
    # parenthesis here; this line looks intended to fold fullwidth parens
    # to ASCII — confirm the original bytes were not garbled in transit.
    normalized = normalized.replace("(", "(").replace(")", ")")

    if not normalized:
        return None

    lowered = normalized.casefold()
    if lowered in INVALID_COUNTRY_VALUES:
        return None
    if NUMERIC_LIKE_PATTERN.fullmatch(normalized):
        return None
    if normalized in CANONICAL_COUNTRY_SET:
        return normalized
    return COUNTRY_ALIAS_MAP.get(lowered)
|
||||
|
||||
|
||||
def get_country_search_variants(value: Any) -> list[str]:
    """Return the deduplicated spellings to search for a country value.

    The input is normalized first; an unrecognized value yields []. Variants
    keep their original casing but are whitespace-normalized and deduplicated
    case-insensitively, canonical name first.
    """
    canonical = normalize_country(value)
    if canonical is None:
        return []

    collected: list[str] = []
    seen_keys: set[str] = set()
    for raw in COUNTRY_VARIANTS_MAP.get(canonical, [canonical]):
        if not isinstance(raw, str):
            continue
        cleaned = re.sub(r"\s+", " ", raw.strip())
        if not cleaned:
            continue
        dedupe_key = cleaned.casefold()
        if dedupe_key not in seen_keys:
            seen_keys.add(dedupe_key)
            collected.append(cleaned)

    return collected
|
||||
@@ -1,5 +1,6 @@
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
|
||||
from sqlalchemy.orm import declarative_base
|
||||
|
||||
@@ -63,6 +64,7 @@ async def init_db():
|
||||
import app.models.user # noqa: F401
|
||||
import app.models.gpu_cluster # noqa: F401
|
||||
import app.models.task # noqa: F401
|
||||
import app.models.data_snapshot # noqa: F401
|
||||
import app.models.datasource # noqa: F401
|
||||
import app.models.datasource_config # noqa: F401
|
||||
import app.models.alert # noqa: F401
|
||||
@@ -71,6 +73,55 @@ async def init_db():
|
||||
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
await conn.execute(
|
||||
text(
|
||||
"""
|
||||
ALTER TABLE collected_data
|
||||
ADD COLUMN IF NOT EXISTS snapshot_id INTEGER,
|
||||
ADD COLUMN IF NOT EXISTS task_id INTEGER,
|
||||
ADD COLUMN IF NOT EXISTS entity_key VARCHAR(255),
|
||||
ADD COLUMN IF NOT EXISTS is_current BOOLEAN DEFAULT TRUE,
|
||||
ADD COLUMN IF NOT EXISTS previous_record_id INTEGER,
|
||||
ADD COLUMN IF NOT EXISTS change_type VARCHAR(20),
|
||||
ADD COLUMN IF NOT EXISTS change_summary JSONB DEFAULT '{}'::jsonb,
|
||||
ADD COLUMN IF NOT EXISTS deleted_at TIMESTAMPTZ
|
||||
"""
|
||||
)
|
||||
)
|
||||
await conn.execute(
|
||||
text(
|
||||
"""
|
||||
ALTER TABLE collection_tasks
|
||||
ADD COLUMN IF NOT EXISTS phase VARCHAR(30) DEFAULT 'queued'
|
||||
"""
|
||||
)
|
||||
)
|
||||
await conn.execute(
|
||||
text(
|
||||
"""
|
||||
CREATE INDEX IF NOT EXISTS idx_collected_data_source_source_id
|
||||
ON collected_data (source, source_id)
|
||||
"""
|
||||
)
|
||||
)
|
||||
await conn.execute(
|
||||
text(
|
||||
"""
|
||||
UPDATE collected_data
|
||||
SET entity_key = source || ':' || COALESCE(source_id, id::text)
|
||||
WHERE entity_key IS NULL
|
||||
"""
|
||||
)
|
||||
)
|
||||
await conn.execute(
|
||||
text(
|
||||
"""
|
||||
UPDATE collected_data
|
||||
SET is_current = TRUE
|
||||
WHERE is_current IS NULL
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
async with async_session_factory() as session:
|
||||
await seed_default_datasources(session)
|
||||
await seed_default_datasources(session)
|
||||
|
||||
@@ -9,7 +9,12 @@ from app.api.v1 import websocket
|
||||
from app.core.config import settings
|
||||
from app.core.websocket.broadcaster import broadcaster
|
||||
from app.db.session import init_db
|
||||
from app.services.scheduler import start_scheduler, stop_scheduler, sync_scheduler_with_datasources
|
||||
from app.services.scheduler import (
|
||||
cleanup_stale_running_tasks,
|
||||
start_scheduler,
|
||||
stop_scheduler,
|
||||
sync_scheduler_with_datasources,
|
||||
)
|
||||
|
||||
|
||||
class WebSocketCORSMiddleware(BaseHTTPMiddleware):
|
||||
@@ -26,6 +31,7 @@ class WebSocketCORSMiddleware(BaseHTTPMiddleware):
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
await init_db()
|
||||
await cleanup_stale_running_tasks()
|
||||
start_scheduler()
|
||||
await sync_scheduler_with_datasources()
|
||||
broadcaster.start()
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from app.models.user import User
|
||||
from app.models.gpu_cluster import GPUCluster
|
||||
from app.models.task import CollectionTask
|
||||
from app.models.data_snapshot import DataSnapshot
|
||||
from app.models.datasource import DataSource
|
||||
from app.models.datasource_config import DataSourceConfig
|
||||
from app.models.alert import Alert, AlertSeverity, AlertStatus
|
||||
@@ -10,6 +11,7 @@ __all__ = [
|
||||
"User",
|
||||
"GPUCluster",
|
||||
"CollectionTask",
|
||||
"DataSnapshot",
|
||||
"DataSource",
|
||||
"DataSourceConfig",
|
||||
"SystemSetting",
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
"""Collected Data model for storing data from all collectors"""
|
||||
|
||||
from sqlalchemy import Column, DateTime, Integer, String, Text, JSON, Index
|
||||
from sqlalchemy import Boolean, Column, DateTime, ForeignKey, Integer, String, Text, JSON, Index
|
||||
from sqlalchemy.sql import func
|
||||
|
||||
from app.core.collected_data_fields import get_record_field
|
||||
from app.db.session import Base
|
||||
|
||||
|
||||
@@ -12,8 +13,11 @@ class CollectedData(Base):
|
||||
__tablename__ = "collected_data"
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
snapshot_id = Column(Integer, ForeignKey("data_snapshots.id"), nullable=True, index=True)
|
||||
task_id = Column(Integer, ForeignKey("collection_tasks.id"), nullable=True, index=True)
|
||||
source = Column(String(100), nullable=False, index=True) # e.g., "top500", "huggingface_models"
|
||||
source_id = Column(String(100), index=True) # Original ID from source, e.g., "rank_1"
|
||||
entity_key = Column(String(255), index=True)
|
||||
data_type = Column(
|
||||
String(50), nullable=False, index=True
|
||||
) # e.g., "supercomputer", "model", "dataset"
|
||||
@@ -23,16 +27,6 @@ class CollectedData(Base):
|
||||
title = Column(String(500))
|
||||
description = Column(Text)
|
||||
|
||||
# Location data (for geo visualization)
|
||||
country = Column(String(100))
|
||||
city = Column(String(100))
|
||||
latitude = Column(String(50))
|
||||
longitude = Column(String(50))
|
||||
|
||||
# Performance metrics
|
||||
value = Column(String(100)) # Generic value field (Rmax, Rpeak, etc.)
|
||||
unit = Column(String(20))
|
||||
|
||||
# Additional metadata as JSON
|
||||
extra_data = Column(
|
||||
"metadata", JSON, default={}
|
||||
@@ -44,11 +38,17 @@ class CollectedData(Base):
|
||||
|
||||
# Status
|
||||
is_valid = Column(Integer, default=1) # 1=valid, 0=invalid
|
||||
is_current = Column(Boolean, default=True, index=True)
|
||||
previous_record_id = Column(Integer, ForeignKey("collected_data.id"), nullable=True, index=True)
|
||||
change_type = Column(String(20), nullable=True)
|
||||
change_summary = Column(JSON, default={})
|
||||
deleted_at = Column(DateTime(timezone=True), nullable=True)
|
||||
|
||||
# Indexes for common queries
|
||||
__table_args__ = (
|
||||
Index("idx_collected_data_source_collected", "source", "collected_at"),
|
||||
Index("idx_collected_data_source_type", "source", "data_type"),
|
||||
Index("idx_collected_data_source_source_id", "source", "source_id"),
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
@@ -58,18 +58,21 @@ class CollectedData(Base):
|
||||
"""Convert to dictionary"""
|
||||
return {
|
||||
"id": self.id,
|
||||
"snapshot_id": self.snapshot_id,
|
||||
"task_id": self.task_id,
|
||||
"source": self.source,
|
||||
"source_id": self.source_id,
|
||||
"entity_key": self.entity_key,
|
||||
"data_type": self.data_type,
|
||||
"name": self.name,
|
||||
"title": self.title,
|
||||
"description": self.description,
|
||||
"country": self.country,
|
||||
"city": self.city,
|
||||
"latitude": self.latitude,
|
||||
"longitude": self.longitude,
|
||||
"value": self.value,
|
||||
"unit": self.unit,
|
||||
"country": get_record_field(self, "country"),
|
||||
"city": get_record_field(self, "city"),
|
||||
"latitude": get_record_field(self, "latitude"),
|
||||
"longitude": get_record_field(self, "longitude"),
|
||||
"value": get_record_field(self, "value"),
|
||||
"unit": get_record_field(self, "unit"),
|
||||
"metadata": self.extra_data,
|
||||
"collected_at": self.collected_at.isoformat()
|
||||
if self.collected_at is not None
|
||||
@@ -77,4 +80,9 @@ class CollectedData(Base):
|
||||
"reference_date": self.reference_date.isoformat()
|
||||
if self.reference_date is not None
|
||||
else None,
|
||||
"is_current": self.is_current,
|
||||
"previous_record_id": self.previous_record_id,
|
||||
"change_type": self.change_type,
|
||||
"change_summary": self.change_summary,
|
||||
"deleted_at": self.deleted_at.isoformat() if self.deleted_at is not None else None,
|
||||
}
|
||||
|
||||
26
backend/app/models/data_snapshot.py
Normal file
26
backend/app/models/data_snapshot.py
Normal file
@@ -0,0 +1,26 @@
|
||||
from sqlalchemy import Boolean, Column, DateTime, ForeignKey, Integer, JSON, String
|
||||
from sqlalchemy.sql import func
|
||||
|
||||
from app.db.session import Base
|
||||
|
||||
|
||||
class DataSnapshot(Base):
    """One collection run's worth of records for a single data source."""

    __tablename__ = "data_snapshots"

    id = Column(Integer, primary_key=True, autoincrement=True)
    datasource_id = Column(Integer, nullable=False, index=True)
    task_id = Column(Integer, ForeignKey("collection_tasks.id"), nullable=True, index=True)
    source = Column(String(100), nullable=False, index=True)
    snapshot_key = Column(String(100), nullable=True, index=True)
    reference_date = Column(DateTime(timezone=True), nullable=True)
    started_at = Column(DateTime(timezone=True), server_default=func.now())
    completed_at = Column(DateTime(timezone=True), nullable=True)
    record_count = Column(Integer, default=0)
    status = Column(String(20), nullable=False, default="running")  # e.g. "running", "failed"
    # Only the latest snapshot per source should keep is_current=True;
    # older ones link forward via parent_snapshot_id.
    is_current = Column(Boolean, default=True, index=True)
    parent_snapshot_id = Column(Integer, ForeignKey("data_snapshots.id"), nullable=True, index=True)
    summary = Column(JSON, default={})
    created_at = Column(DateTime(timezone=True), server_default=func.now())

    def __repr__(self):
        return f"<DataSnapshot {self.id}: {self.source}/{self.status}>"
|
||||
@@ -12,6 +12,7 @@ class CollectionTask(Base):
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
datasource_id = Column(Integer, nullable=False, index=True)
|
||||
status = Column(String(20), nullable=False) # pending, running, success, failed, cancelled
|
||||
phase = Column(String(30), default="queued")
|
||||
started_at = Column(DateTime(timezone=True))
|
||||
completed_at = Column(DateTime(timezone=True))
|
||||
records_processed = Column(Integer, default=0)
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
from typing import Dict, Any, List
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
from app.services.collectors.base import BaseCollector
|
||||
from app.core.data_sources import get_data_sources_config
|
||||
|
||||
from app.services.collectors.base import BaseCollector
|
||||
|
||||
|
||||
class ArcGISCableLandingRelationCollector(BaseCollector):
|
||||
@@ -18,45 +19,129 @@ class ArcGISCableLandingRelationCollector(BaseCollector):
|
||||
def base_url(self) -> str:
|
||||
if self._resolved_url:
|
||||
return self._resolved_url
|
||||
from app.core.data_sources import get_data_sources_config
|
||||
|
||||
config = get_data_sources_config()
|
||||
return config.get_yaml_url("arcgis_cable_landing_relation")
|
||||
|
||||
def _layer_url(self, layer_id: int) -> str:
|
||||
if "/FeatureServer/" not in self.base_url:
|
||||
return self.base_url
|
||||
prefix = self.base_url.split("/FeatureServer/")[0]
|
||||
return f"{prefix}/FeatureServer/{layer_id}/query"
|
||||
|
||||
async def _fetch_layer_attributes(
|
||||
self, client: httpx.AsyncClient, layer_id: int
|
||||
) -> List[Dict[str, Any]]:
|
||||
response = await client.get(
|
||||
self._layer_url(layer_id),
|
||||
params={
|
||||
"where": "1=1",
|
||||
"outFields": "*",
|
||||
"returnGeometry": "false",
|
||||
"f": "json",
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
return [feature.get("attributes", {}) for feature in data.get("features", [])]
|
||||
|
||||
async def _fetch_relation_features(self, client: httpx.AsyncClient) -> List[Dict[str, Any]]:
|
||||
response = await client.get(
|
||||
self.base_url,
|
||||
params={
|
||||
"where": "1=1",
|
||||
"outFields": "*",
|
||||
"returnGeometry": "true",
|
||||
"f": "geojson",
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
return data.get("features", [])
|
||||
|
||||
async def fetch(self) -> List[Dict[str, Any]]:
|
||||
params = {"where": "1=1", "outFields": "*", "returnGeometry": "true", "f": "geojson"}
|
||||
|
||||
async with httpx.AsyncClient(timeout=60.0) as client:
|
||||
response = await client.get(self.base_url, params=params)
|
||||
response.raise_for_status()
|
||||
return self.parse_response(response.json())
|
||||
relation_features, landing_rows, cable_rows = await asyncio.gather(
|
||||
self._fetch_relation_features(client),
|
||||
self._fetch_layer_attributes(client, 1),
|
||||
self._fetch_layer_attributes(client, 2),
|
||||
)
|
||||
return self.parse_response(relation_features, landing_rows, cable_rows)
|
||||
|
||||
def parse_response(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
|
||||
result = []
|
||||
def _build_landing_lookup(self, landing_rows: List[Dict[str, Any]]) -> Dict[int, Dict[str, Any]]:
|
||||
lookup: Dict[int, Dict[str, Any]] = {}
|
||||
for row in landing_rows:
|
||||
city_id = row.get("city_id")
|
||||
if city_id is None:
|
||||
continue
|
||||
lookup[int(city_id)] = {
|
||||
"landing_point_id": row.get("landing_point_id") or city_id,
|
||||
"landing_point_name": row.get("Name") or row.get("name") or "",
|
||||
"facility": row.get("facility") or "",
|
||||
"status": row.get("status") or "",
|
||||
"country": row.get("country") or "",
|
||||
}
|
||||
return lookup
|
||||
|
||||
features = data.get("features", [])
|
||||
for feature in features:
|
||||
def _build_cable_lookup(self, cable_rows: List[Dict[str, Any]]) -> Dict[int, Dict[str, Any]]:
|
||||
lookup: Dict[int, Dict[str, Any]] = {}
|
||||
for row in cable_rows:
|
||||
cable_id = row.get("cable_id")
|
||||
if cable_id is None:
|
||||
continue
|
||||
lookup[int(cable_id)] = {
|
||||
"cable_name": row.get("Name") or "",
|
||||
"status": row.get("status") or "active",
|
||||
}
|
||||
return lookup
|
||||
|
||||
def parse_response(
|
||||
self,
|
||||
relation_features: List[Dict[str, Any]],
|
||||
landing_rows: List[Dict[str, Any]],
|
||||
cable_rows: List[Dict[str, Any]],
|
||||
) -> List[Dict[str, Any]]:
|
||||
result: List[Dict[str, Any]] = []
|
||||
landing_lookup = self._build_landing_lookup(landing_rows)
|
||||
cable_lookup = self._build_cable_lookup(cable_rows)
|
||||
|
||||
for feature in relation_features:
|
||||
props = feature.get("properties", {})
|
||||
|
||||
try:
|
||||
city_id = props.get("city_id")
|
||||
cable_id = props.get("cable_id")
|
||||
landing_info = landing_lookup.get(int(city_id), {}) if city_id is not None else {}
|
||||
cable_info = cable_lookup.get(int(cable_id), {}) if cable_id is not None else {}
|
||||
|
||||
cable_name = cable_info.get("cable_name") or props.get("cable_name") or "Unknown"
|
||||
landing_point_name = (
|
||||
landing_info.get("landing_point_name")
|
||||
or props.get("landing_point_name")
|
||||
or "Unknown"
|
||||
)
|
||||
facility = landing_info.get("facility") or props.get("facility") or "-"
|
||||
status = cable_info.get("status") or landing_info.get("status") or props.get("status") or "-"
|
||||
country = landing_info.get("country") or props.get("country") or ""
|
||||
landing_point_id = landing_info.get("landing_point_id") or props.get("landing_point_id") or city_id
|
||||
|
||||
entry = {
|
||||
"source_id": f"arcgis_relation_{props.get('OBJECTID', props.get('id', ''))}",
|
||||
"name": f"{props.get('cable_name', 'Unknown')} - {props.get('landing_point_name', 'Unknown')}",
|
||||
"country": props.get("country", ""),
|
||||
"city": props.get("landing_point_name", ""),
|
||||
"name": f"{cable_name} - {landing_point_name}",
|
||||
"country": country,
|
||||
"city": landing_point_name,
|
||||
"latitude": str(props.get("latitude", "")) if props.get("latitude") else "",
|
||||
"longitude": str(props.get("longitude", "")) if props.get("longitude") else "",
|
||||
"value": "",
|
||||
"unit": "",
|
||||
"metadata": {
|
||||
"objectid": props.get("OBJECTID"),
|
||||
"city_id": props.get("city_id"),
|
||||
"cable_id": props.get("cable_id"),
|
||||
"cable_name": props.get("cable_name"),
|
||||
"landing_point_id": props.get("landing_point_id"),
|
||||
"landing_point_name": props.get("landing_point_name"),
|
||||
"facility": props.get("facility"),
|
||||
"status": props.get("status"),
|
||||
"city_id": city_id,
|
||||
"cable_id": cable_id,
|
||||
"cable_name": cable_name,
|
||||
"landing_point_id": landing_point_id,
|
||||
"landing_point_name": landing_point_name,
|
||||
"facility": facility,
|
||||
"status": status,
|
||||
},
|
||||
"reference_date": datetime.utcnow().strftime("%Y-%m-%d"),
|
||||
}
|
||||
|
||||
@@ -4,10 +4,12 @@ from abc import ABC, abstractmethod
|
||||
from typing import Dict, List, Any, Optional
|
||||
from datetime import datetime
|
||||
import httpx
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy import select, text
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.collected_data_fields import build_dynamic_metadata, get_record_field
|
||||
from app.core.config import settings
|
||||
from app.core.countries import normalize_country
|
||||
|
||||
|
||||
class BaseCollector(ABC):
|
||||
@@ -39,6 +41,11 @@ class BaseCollector(ABC):
|
||||
records_processed / self._current_task.total_records
|
||||
) * 100
|
||||
|
||||
async def set_phase(self, phase: str):
    """Record the task's current pipeline phase and persist it immediately.

    No-op when the collector is not attached to a running task and session.
    """
    task = self._current_task
    session = self._db_session
    if task and session:
        task.phase = phase
        await session.commit()
||||
|
||||
@abstractmethod
|
||||
async def fetch(self) -> List[Dict[str, Any]]:
|
||||
"""Fetch raw data from source"""
|
||||
@@ -48,14 +55,87 @@ class BaseCollector(ABC):
|
||||
"""Transform raw data to internal format (default: pass through)"""
|
||||
return raw_data
|
||||
|
||||
def _parse_reference_date(self, value: Any) -> Optional[datetime]:
|
||||
if not value:
|
||||
return None
|
||||
if isinstance(value, datetime):
|
||||
return value
|
||||
if isinstance(value, str):
|
||||
return datetime.fromisoformat(value.replace("Z", "+00:00"))
|
||||
return None
|
||||
|
||||
def _build_comparable_payload(self, record: Any) -> Dict[str, Any]:
    """Flatten a collected-data row into the dict used for change comparison.

    Location/metric fields go through get_record_field so JSON metadata
    (with alias support) wins over legacy columns; reference_date is
    ISO-formatted for stable comparison.
    """
    ref_date = getattr(record, "reference_date", None)
    return {
        "name": getattr(record, "name", None),
        "title": getattr(record, "title", None),
        "description": getattr(record, "description", None),
        "country": get_record_field(record, "country"),
        "city": get_record_field(record, "city"),
        "latitude": get_record_field(record, "latitude"),
        "longitude": get_record_field(record, "longitude"),
        "value": get_record_field(record, "value"),
        "unit": get_record_field(record, "unit"),
        "metadata": getattr(record, "extra_data", None) or {},
        "reference_date": ref_date.isoformat() if ref_date else None,
    }
|
||||
|
||||
async def _create_snapshot(
    self,
    db: AsyncSession,
    task_id: int,
    data: List[Dict[str, Any]],
    started_at: datetime,
) -> int:
    """Open a new running DataSnapshot for this collection and return its id.

    The newest parseable reference_date found in *data* becomes the
    snapshot's reference_date. The previously-current snapshot for this
    source (if any) is linked as the parent and demoted from is_current
    in the same commit.
    """
    from app.models.data_snapshot import DataSnapshot

    parsed = (self._parse_reference_date(item.get("reference_date")) for item in data)
    candidate_dates = [d for d in parsed if d is not None]
    reference_date = max(candidate_dates) if candidate_dates else None

    # Most recently completed current snapshot for this source, ids as tiebreak.
    query = (
        select(DataSnapshot)
        .where(DataSnapshot.source == self.name, DataSnapshot.is_current == True)  # noqa: E712
        .order_by(DataSnapshot.completed_at.desc().nullslast(), DataSnapshot.id.desc())
        .limit(1)
    )
    previous_snapshot = (await db.execute(query)).scalar_one_or_none()

    snapshot = DataSnapshot(
        datasource_id=getattr(self, "_datasource_id", 1),
        task_id=task_id,
        source=self.name,
        snapshot_key=f"{self.name}:{task_id}",
        reference_date=reference_date,
        started_at=started_at,
        status="running",
        is_current=True,
        parent_snapshot_id=previous_snapshot.id if previous_snapshot else None,
        summary={},
    )
    db.add(snapshot)

    if previous_snapshot:
        previous_snapshot.is_current = False

    await db.commit()
    # NOTE(review): reading snapshot.id after commit relies on the session
    # refreshing expired attributes; confirm expire_on_commit=False for this
    # AsyncSession, otherwise this access can fail under asyncio.
    return snapshot.id
|
||||
|
||||
async def run(self, db: AsyncSession) -> Dict[str, Any]:
|
||||
"""Full pipeline: fetch -> transform -> save"""
|
||||
from app.services.collectors.registry import collector_registry
|
||||
from app.models.task import CollectionTask
|
||||
from app.models.collected_data import CollectedData
|
||||
from app.models.data_snapshot import DataSnapshot
|
||||
|
||||
start_time = datetime.utcnow()
|
||||
datasource_id = getattr(self, "_datasource_id", 1)
|
||||
snapshot_id: Optional[int] = None
|
||||
|
||||
if not collector_registry.is_active(self.name):
|
||||
return {"status": "skipped", "reason": "Collector is disabled"}
|
||||
@@ -63,6 +143,7 @@ class BaseCollector(ABC):
|
||||
task = CollectionTask(
|
||||
datasource_id=datasource_id,
|
||||
status="running",
|
||||
phase="queued",
|
||||
started_at=start_time,
|
||||
)
|
||||
db.add(task)
|
||||
@@ -75,15 +156,20 @@ class BaseCollector(ABC):
|
||||
await self.resolve_url(db)
|
||||
|
||||
try:
|
||||
await self.set_phase("fetching")
|
||||
raw_data = await self.fetch()
|
||||
task.total_records = len(raw_data)
|
||||
await db.commit()
|
||||
|
||||
await self.set_phase("transforming")
|
||||
data = self.transform(raw_data)
|
||||
snapshot_id = await self._create_snapshot(db, task_id, data, start_time)
|
||||
|
||||
records_count = await self._save_data(db, data)
|
||||
await self.set_phase("saving")
|
||||
records_count = await self._save_data(db, data, task_id=task_id, snapshot_id=snapshot_id)
|
||||
|
||||
task.status = "success"
|
||||
task.phase = "completed"
|
||||
task.records_processed = records_count
|
||||
task.progress = 100.0
|
||||
task.completed_at = datetime.utcnow()
|
||||
@@ -97,8 +183,15 @@ class BaseCollector(ABC):
|
||||
}
|
||||
except Exception as e:
|
||||
task.status = "failed"
|
||||
task.phase = "failed"
|
||||
task.error_message = str(e)
|
||||
task.completed_at = datetime.utcnow()
|
||||
if snapshot_id is not None:
|
||||
snapshot = await db.get(DataSnapshot, snapshot_id)
|
||||
if snapshot:
|
||||
snapshot.status = "failed"
|
||||
snapshot.completed_at = datetime.utcnow()
|
||||
snapshot.summary = {"error": str(e)}
|
||||
await db.commit()
|
||||
|
||||
return {
|
||||
@@ -108,53 +201,163 @@ class BaseCollector(ABC):
|
||||
"execution_time_seconds": (datetime.utcnow() - start_time).total_seconds(),
|
||||
}
|
||||
|
||||
async def _save_data(self, db: AsyncSession, data: List[Dict[str, Any]]) -> int:
|
||||
async def _save_data(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
data: List[Dict[str, Any]],
|
||||
task_id: Optional[int] = None,
|
||||
snapshot_id: Optional[int] = None,
|
||||
) -> int:
|
||||
"""Save transformed data to database"""
|
||||
from app.models.collected_data import CollectedData
|
||||
from app.models.data_snapshot import DataSnapshot
|
||||
|
||||
if not data:
|
||||
if snapshot_id is not None:
|
||||
snapshot = await db.get(DataSnapshot, snapshot_id)
|
||||
if snapshot:
|
||||
snapshot.record_count = 0
|
||||
snapshot.summary = {"created": 0, "updated": 0, "unchanged": 0}
|
||||
snapshot.status = "success"
|
||||
snapshot.completed_at = datetime.utcnow()
|
||||
await db.commit()
|
||||
return 0
|
||||
|
||||
collected_at = datetime.utcnow()
|
||||
records_added = 0
|
||||
created_count = 0
|
||||
updated_count = 0
|
||||
unchanged_count = 0
|
||||
seen_entity_keys: set[str] = set()
|
||||
previous_current_keys: set[str] = set()
|
||||
|
||||
previous_current_result = await db.execute(
|
||||
select(CollectedData.entity_key).where(
|
||||
CollectedData.source == self.name,
|
||||
CollectedData.is_current == True,
|
||||
)
|
||||
)
|
||||
previous_current_keys = {row[0] for row in previous_current_result.fetchall() if row[0]}
|
||||
|
||||
for i, item in enumerate(data):
|
||||
print(
|
||||
f"DEBUG: Saving item {i}: name={item.get('name')}, metadata={item.get('metadata', 'NOT FOUND')}"
|
||||
)
|
||||
raw_metadata = item.get("metadata", {})
|
||||
extra_data = build_dynamic_metadata(
|
||||
raw_metadata,
|
||||
country=item.get("country"),
|
||||
city=item.get("city"),
|
||||
latitude=item.get("latitude"),
|
||||
longitude=item.get("longitude"),
|
||||
value=item.get("value"),
|
||||
unit=item.get("unit"),
|
||||
)
|
||||
normalized_country = normalize_country(item.get("country"))
|
||||
if normalized_country is not None:
|
||||
extra_data["country"] = normalized_country
|
||||
|
||||
if item.get("country") and normalized_country != item.get("country"):
|
||||
extra_data["raw_country"] = item.get("country")
|
||||
if normalized_country is None:
|
||||
extra_data["country_validation"] = "invalid"
|
||||
|
||||
source_id = item.get("source_id") or item.get("id")
|
||||
reference_date = (
|
||||
self._parse_reference_date(item.get("reference_date"))
|
||||
)
|
||||
source_id_str = str(source_id) if source_id is not None else None
|
||||
entity_key = f"{self.name}:{source_id_str}" if source_id_str else f"{self.name}:{i}"
|
||||
previous_record = None
|
||||
|
||||
if entity_key and entity_key not in seen_entity_keys:
|
||||
result = await db.execute(
|
||||
select(CollectedData)
|
||||
.where(
|
||||
CollectedData.source == self.name,
|
||||
CollectedData.entity_key == entity_key,
|
||||
CollectedData.is_current == True,
|
||||
)
|
||||
.order_by(CollectedData.collected_at.desc().nullslast(), CollectedData.id.desc())
|
||||
)
|
||||
previous_records = result.scalars().all()
|
||||
if previous_records:
|
||||
previous_record = previous_records[0]
|
||||
for old_record in previous_records:
|
||||
old_record.is_current = False
|
||||
|
||||
record = CollectedData(
|
||||
snapshot_id=snapshot_id,
|
||||
task_id=task_id,
|
||||
source=self.name,
|
||||
source_id=item.get("source_id") or item.get("id"),
|
||||
source_id=source_id_str,
|
||||
entity_key=entity_key,
|
||||
data_type=self.data_type,
|
||||
name=item.get("name"),
|
||||
title=item.get("title"),
|
||||
description=item.get("description"),
|
||||
country=item.get("country"),
|
||||
city=item.get("city"),
|
||||
latitude=str(item.get("latitude", ""))
|
||||
if item.get("latitude") is not None
|
||||
else None,
|
||||
longitude=str(item.get("longitude", ""))
|
||||
if item.get("longitude") is not None
|
||||
else None,
|
||||
value=item.get("value"),
|
||||
unit=item.get("unit"),
|
||||
extra_data=item.get("metadata", {}),
|
||||
extra_data=extra_data,
|
||||
collected_at=collected_at,
|
||||
reference_date=datetime.fromisoformat(
|
||||
item.get("reference_date").replace("Z", "+00:00")
|
||||
)
|
||||
if item.get("reference_date")
|
||||
else None,
|
||||
reference_date=reference_date,
|
||||
is_valid=1,
|
||||
is_current=True,
|
||||
previous_record_id=previous_record.id if previous_record else None,
|
||||
deleted_at=None,
|
||||
)
|
||||
|
||||
if previous_record is None:
|
||||
record.change_type = "created"
|
||||
record.change_summary = {}
|
||||
created_count += 1
|
||||
else:
|
||||
previous_payload = self._build_comparable_payload(previous_record)
|
||||
current_payload = self._build_comparable_payload(record)
|
||||
if current_payload == previous_payload:
|
||||
record.change_type = "unchanged"
|
||||
record.change_summary = {}
|
||||
unchanged_count += 1
|
||||
else:
|
||||
changed_fields = [
|
||||
key for key in current_payload.keys() if current_payload[key] != previous_payload.get(key)
|
||||
]
|
||||
record.change_type = "updated"
|
||||
record.change_summary = {"changed_fields": changed_fields}
|
||||
updated_count += 1
|
||||
|
||||
db.add(record)
|
||||
seen_entity_keys.add(entity_key)
|
||||
records_added += 1
|
||||
|
||||
if i % 100 == 0:
|
||||
self.update_progress(i + 1)
|
||||
await db.commit()
|
||||
|
||||
if snapshot_id is not None:
|
||||
deleted_keys = previous_current_keys - seen_entity_keys
|
||||
await db.execute(
|
||||
text(
|
||||
"""
|
||||
UPDATE collected_data
|
||||
SET is_current = FALSE
|
||||
WHERE source = :source
|
||||
AND snapshot_id IS DISTINCT FROM :snapshot_id
|
||||
AND COALESCE(is_current, TRUE) = TRUE
|
||||
"""
|
||||
),
|
||||
{"source": self.name, "snapshot_id": snapshot_id},
|
||||
)
|
||||
snapshot = await db.get(DataSnapshot, snapshot_id)
|
||||
if snapshot:
|
||||
snapshot.record_count = records_added
|
||||
snapshot.status = "success"
|
||||
snapshot.completed_at = datetime.utcnow()
|
||||
snapshot.summary = {
|
||||
"created": created_count,
|
||||
"updated": updated_count,
|
||||
"unchanged": unchanged_count,
|
||||
"deleted": len(deleted_keys),
|
||||
}
|
||||
|
||||
await db.commit()
|
||||
self.update_progress(len(data))
|
||||
return records_added
|
||||
|
||||
@@ -76,7 +76,7 @@ class PeeringDBIXPCollector(HTTPCollector):
|
||||
print(f"Warning: PeeringDB collection failed after {max_retries} retries: {last_error}")
|
||||
return {}
|
||||
|
||||
async def collect(self) -> List[Dict[str, Any]]:
|
||||
async def fetch(self) -> List[Dict[str, Any]]:
|
||||
"""Collect IXP data from PeeringDB with rate limit handling"""
|
||||
response_data = await self.fetch_with_retry()
|
||||
if not response_data:
|
||||
@@ -177,7 +177,7 @@ class PeeringDBNetworkCollector(HTTPCollector):
|
||||
print(f"Warning: PeeringDB collection failed after {max_retries} retries: {last_error}")
|
||||
return {}
|
||||
|
||||
async def collect(self) -> List[Dict[str, Any]]:
|
||||
async def fetch(self) -> List[Dict[str, Any]]:
|
||||
"""Collect Network data from PeeringDB with rate limit handling"""
|
||||
response_data = await self.fetch_with_retry()
|
||||
if not response_data:
|
||||
@@ -280,7 +280,7 @@ class PeeringDBFacilityCollector(HTTPCollector):
|
||||
print(f"Warning: PeeringDB collection failed after {max_retries} retries: {last_error}")
|
||||
return {}
|
||||
|
||||
async def collect(self) -> List[Dict[str, Any]]:
|
||||
async def fetch(self) -> List[Dict[str, Any]]:
|
||||
"""Collect Facility data from PeeringDB with rate limit handling"""
|
||||
response_data = await self.fetch_with_retry()
|
||||
if not response_data:
|
||||
|
||||
@@ -4,9 +4,9 @@ Collects data from TOP500 supercomputer rankings.
|
||||
https://top500.org/lists/top500/
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import re
|
||||
from typing import Dict, Any, List
|
||||
from datetime import datetime
|
||||
from bs4 import BeautifulSoup
|
||||
import httpx
|
||||
|
||||
@@ -21,14 +21,108 @@ class TOP500Collector(BaseCollector):
|
||||
data_type = "supercomputer"
|
||||
|
||||
async def fetch(self) -> List[Dict[str, Any]]:
|
||||
"""Fetch TOP500 data from website (scraping)"""
|
||||
# Get the latest list page
|
||||
"""Fetch TOP500 list data and enrich each row with detail-page metadata."""
|
||||
url = "https://top500.org/lists/top500/list/2025/11/"
|
||||
|
||||
async with httpx.AsyncClient(timeout=60.0) as client:
|
||||
async with httpx.AsyncClient(timeout=60.0, follow_redirects=True) as client:
|
||||
response = await client.get(url)
|
||||
response.raise_for_status()
|
||||
return self.parse_response(response.text)
|
||||
entries = self.parse_response(response.text)
|
||||
|
||||
semaphore = asyncio.Semaphore(8)
|
||||
|
||||
async def enrich(entry: Dict[str, Any]) -> Dict[str, Any]:
|
||||
detail_url = entry.pop("_detail_url", "")
|
||||
if not detail_url:
|
||||
return entry
|
||||
|
||||
async with semaphore:
|
||||
try:
|
||||
detail_response = await client.get(detail_url)
|
||||
detail_response.raise_for_status()
|
||||
entry["metadata"].update(self.parse_detail_response(detail_response.text))
|
||||
except Exception:
|
||||
entry["metadata"]["detail_fetch_failed"] = True
|
||||
return entry
|
||||
|
||||
return await asyncio.gather(*(enrich(entry) for entry in entries))
|
||||
|
||||
def _extract_system_fields(self, system_cell) -> Dict[str, str]:
|
||||
link = system_cell.find("a")
|
||||
system_name = link.get_text(" ", strip=True) if link else system_cell.get_text(" ", strip=True)
|
||||
detail_url = ""
|
||||
if link and link.get("href"):
|
||||
detail_url = f"https://top500.org{link.get('href')}"
|
||||
|
||||
manufacturer = ""
|
||||
if link and link.next_sibling:
|
||||
manufacturer = str(link.next_sibling).strip(" ,\n\t")
|
||||
|
||||
cell_text = system_cell.get_text("\n", strip=True)
|
||||
lines = [line.strip(" ,") for line in cell_text.splitlines() if line.strip()]
|
||||
|
||||
site = ""
|
||||
country = ""
|
||||
if lines:
|
||||
system_name = lines[0]
|
||||
if len(lines) >= 3:
|
||||
site = lines[-2]
|
||||
country = lines[-1]
|
||||
elif len(lines) == 2:
|
||||
country = lines[-1]
|
||||
|
||||
if not manufacturer and len(lines) >= 2:
|
||||
manufacturer = lines[1]
|
||||
|
||||
return {
|
||||
"name": system_name,
|
||||
"manufacturer": manufacturer,
|
||||
"site": site,
|
||||
"country": country,
|
||||
"detail_url": detail_url,
|
||||
}
|
||||
|
||||
def parse_detail_response(self, html: str) -> Dict[str, Any]:
|
||||
soup = BeautifulSoup(html, "html.parser")
|
||||
detail_table = soup.find("table", {"class": "table table-condensed"})
|
||||
if not detail_table:
|
||||
return {}
|
||||
|
||||
detail_map: Dict[str, Any] = {}
|
||||
label_aliases = {
|
||||
"Site": "site",
|
||||
"Manufacturer": "manufacturer",
|
||||
"Cores": "cores",
|
||||
"Processor": "processor",
|
||||
"Interconnect": "interconnect",
|
||||
"Installation Year": "installation_year",
|
||||
"Linpack Performance (Rmax)": "rmax",
|
||||
"Theoretical Peak (Rpeak)": "rpeak",
|
||||
"Nmax": "nmax",
|
||||
"HPCG": "hpcg",
|
||||
"Power": "power",
|
||||
"Power Measurement Level": "power_measurement_level",
|
||||
"Operating System": "operating_system",
|
||||
"Compiler": "compiler",
|
||||
"Math Library": "math_library",
|
||||
"MPI": "mpi",
|
||||
}
|
||||
|
||||
for row in detail_table.find_all("tr"):
|
||||
header = row.find("th")
|
||||
value_cell = row.find("td")
|
||||
if not header or not value_cell:
|
||||
continue
|
||||
|
||||
label = header.get_text(" ", strip=True).rstrip(":")
|
||||
key = label_aliases.get(label)
|
||||
if not key:
|
||||
continue
|
||||
|
||||
value = value_cell.get_text(" ", strip=True)
|
||||
detail_map[key] = value
|
||||
|
||||
return detail_map
|
||||
|
||||
def parse_response(self, html: str) -> List[Dict[str, Any]]:
|
||||
"""Parse TOP500 HTML response"""
|
||||
@@ -36,27 +130,26 @@ class TOP500Collector(BaseCollector):
|
||||
soup = BeautifulSoup(html, "html.parser")
|
||||
|
||||
# Find the table with TOP500 data
|
||||
table = soup.find("table", {"class": "top500-table"})
|
||||
if not table:
|
||||
# Try alternative table selector
|
||||
table = soup.find("table", {"id": "top500"})
|
||||
table = None
|
||||
for candidate in soup.find_all("table"):
|
||||
header_cells = [
|
||||
cell.get_text(" ", strip=True) for cell in candidate.select("thead th")
|
||||
]
|
||||
normalized_headers = [header.lower() for header in header_cells]
|
||||
if (
|
||||
"rank" in normalized_headers
|
||||
and "system" in normalized_headers
|
||||
and any("cores" in header for header in normalized_headers)
|
||||
and any("rmax" in header for header in normalized_headers)
|
||||
):
|
||||
table = candidate
|
||||
break
|
||||
|
||||
if not table:
|
||||
# Try to find any table with rank data
|
||||
tables = soup.find_all("table")
|
||||
for t in tables:
|
||||
if t.find(string=re.compile(r"Rank.*System.*Cores.*Rmax", re.I)):
|
||||
table = t
|
||||
break
|
||||
|
||||
if not table:
|
||||
# Fallback: try to extract data from any table
|
||||
tables = soup.find_all("table")
|
||||
if tables:
|
||||
table = tables[0]
|
||||
table = soup.find("table", {"class": "top500-table"}) or soup.find("table", {"id": "top500"})
|
||||
|
||||
if table:
|
||||
rows = table.find_all("tr")
|
||||
rows = table.select("tr")
|
||||
for row in rows[1:]: # Skip header row
|
||||
cells = row.find_all(["td", "th"])
|
||||
if len(cells) >= 6:
|
||||
@@ -68,43 +161,26 @@ class TOP500Collector(BaseCollector):
|
||||
|
||||
rank = int(rank_text)
|
||||
|
||||
# System name (may contain link)
|
||||
system_cell = cells[1]
|
||||
system_name = system_cell.get_text(strip=True)
|
||||
# Try to get full name from link title or data attribute
|
||||
link = system_cell.find("a")
|
||||
if link and link.get("title"):
|
||||
system_name = link.get("title")
|
||||
system_fields = self._extract_system_fields(system_cell)
|
||||
system_name = system_fields["name"]
|
||||
manufacturer = system_fields["manufacturer"]
|
||||
site = system_fields["site"]
|
||||
country = system_fields["country"]
|
||||
detail_url = system_fields["detail_url"]
|
||||
|
||||
# Country
|
||||
country_cell = cells[2]
|
||||
country = country_cell.get_text(strip=True)
|
||||
# Try to get country from data attribute or image alt
|
||||
img = country_cell.find("img")
|
||||
if img and img.get("alt"):
|
||||
country = img.get("alt")
|
||||
|
||||
# Extract location (city)
|
||||
city = ""
|
||||
location_text = country_cell.get_text(strip=True)
|
||||
if "(" in location_text and ")" in location_text:
|
||||
city = location_text.split("(")[0].strip()
|
||||
cores = cells[2].get_text(strip=True).replace(",", "")
|
||||
|
||||
# Cores
|
||||
cores = cells[3].get_text(strip=True).replace(",", "")
|
||||
|
||||
# Rmax
|
||||
rmax_text = cells[4].get_text(strip=True)
|
||||
rmax_text = cells[3].get_text(strip=True)
|
||||
rmax = self._parse_performance(rmax_text)
|
||||
|
||||
# Rpeak
|
||||
rpeak_text = cells[5].get_text(strip=True)
|
||||
rpeak_text = cells[4].get_text(strip=True)
|
||||
rpeak = self._parse_performance(rpeak_text)
|
||||
|
||||
# Power (optional)
|
||||
power = ""
|
||||
if len(cells) >= 7:
|
||||
power = cells[6].get_text(strip=True)
|
||||
if len(cells) >= 6:
|
||||
power = cells[5].get_text(strip=True).replace(",", "")
|
||||
|
||||
entry = {
|
||||
"source_id": f"top500_{rank}",
|
||||
@@ -117,10 +193,14 @@ class TOP500Collector(BaseCollector):
|
||||
"unit": "PFlop/s",
|
||||
"metadata": {
|
||||
"rank": rank,
|
||||
"r_peak": rpeak,
|
||||
"power": power,
|
||||
"cores": cores,
|
||||
"rmax": rmax_text,
|
||||
"rpeak": rpeak_text,
|
||||
"power": power,
|
||||
"manufacturer": manufacturer,
|
||||
"site": site,
|
||||
},
|
||||
"_detail_url": detail_url,
|
||||
"reference_date": "2025-11-01",
|
||||
}
|
||||
data.append(entry)
|
||||
@@ -184,10 +264,15 @@ class TOP500Collector(BaseCollector):
|
||||
"unit": "PFlop/s",
|
||||
"metadata": {
|
||||
"rank": 1,
|
||||
"r_peak": 2746.38,
|
||||
"power": 29581,
|
||||
"cores": 11039616,
|
||||
"cores": "11039616",
|
||||
"rmax": "1742.00",
|
||||
"rpeak": "2746.38",
|
||||
"power": "29581",
|
||||
"manufacturer": "HPE",
|
||||
"site": "DOE/NNSA/LLNL",
|
||||
"processor": "AMD 4th Gen EPYC 24C 1.8GHz",
|
||||
"interconnect": "Slingshot-11",
|
||||
"installation_year": "2025",
|
||||
},
|
||||
"reference_date": "2025-11-01",
|
||||
},
|
||||
@@ -202,10 +287,12 @@ class TOP500Collector(BaseCollector):
|
||||
"unit": "PFlop/s",
|
||||
"metadata": {
|
||||
"rank": 2,
|
||||
"r_peak": 2055.72,
|
||||
"power": 24607,
|
||||
"cores": 9066176,
|
||||
"cores": "9066176",
|
||||
"rmax": "1353.00",
|
||||
"rpeak": "2055.72",
|
||||
"power": "24607",
|
||||
"manufacturer": "HPE",
|
||||
"site": "DOE/SC/Oak Ridge National Laboratory",
|
||||
},
|
||||
"reference_date": "2025-11-01",
|
||||
},
|
||||
@@ -220,9 +307,10 @@ class TOP500Collector(BaseCollector):
|
||||
"unit": "PFlop/s",
|
||||
"metadata": {
|
||||
"rank": 3,
|
||||
"r_peak": 1980.01,
|
||||
"power": 38698,
|
||||
"cores": 9264128,
|
||||
"cores": "9264128",
|
||||
"rmax": "1012.00",
|
||||
"rpeak": "1980.01",
|
||||
"power": "38698",
|
||||
"manufacturer": "Intel",
|
||||
},
|
||||
"reference_date": "2025-11-01",
|
||||
|
||||
@@ -2,8 +2,8 @@
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
||||
from apscheduler.triggers.interval import IntervalTrigger
|
||||
@@ -11,6 +11,7 @@ from sqlalchemy import select
|
||||
|
||||
from app.db.session import async_session_factory
|
||||
from app.models.datasource import DataSource
|
||||
from app.models.task import CollectionTask
|
||||
from app.services.collectors.registry import collector_registry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -89,6 +90,35 @@ async def run_collector_task(collector_name: str):
|
||||
logger.exception("Collector %s failed: %s", collector_name, exc)
|
||||
|
||||
|
||||
async def cleanup_stale_running_tasks(max_age_hours: int = 2) -> int:
|
||||
"""Mark stale running tasks as failed after restarts or collector hangs."""
|
||||
cutoff = datetime.utcnow() - timedelta(hours=max_age_hours)
|
||||
|
||||
async with async_session_factory() as db:
|
||||
result = await db.execute(
|
||||
select(CollectionTask).where(
|
||||
CollectionTask.status == "running",
|
||||
CollectionTask.started_at.is_not(None),
|
||||
CollectionTask.started_at < cutoff,
|
||||
)
|
||||
)
|
||||
stale_tasks = result.scalars().all()
|
||||
|
||||
for task in stale_tasks:
|
||||
task.status = "failed"
|
||||
task.phase = "failed"
|
||||
task.completed_at = datetime.utcnow()
|
||||
existing_error = (task.error_message or "").strip()
|
||||
cleanup_error = "Marked failed automatically after stale running task cleanup"
|
||||
task.error_message = f"{existing_error}\n{cleanup_error}".strip() if existing_error else cleanup_error
|
||||
|
||||
if stale_tasks:
|
||||
await db.commit()
|
||||
logger.warning("Cleaned up %s stale running collection task(s)", len(stale_tasks))
|
||||
|
||||
return len(stale_tasks)
|
||||
|
||||
|
||||
def start_scheduler() -> None:
|
||||
"""Start the scheduler."""
|
||||
if not scheduler.running:
|
||||
@@ -144,6 +174,19 @@ def get_scheduler_jobs() -> list[Dict[str, Any]]:
|
||||
return jobs
|
||||
|
||||
|
||||
async def get_latest_task_id_for_datasource(datasource_id: int) -> Optional[int]:
|
||||
from app.models.task import CollectionTask
|
||||
|
||||
async with async_session_factory() as db:
|
||||
result = await db.execute(
|
||||
select(CollectionTask.id)
|
||||
.where(CollectionTask.datasource_id == datasource_id)
|
||||
.order_by(CollectionTask.created_at.desc(), CollectionTask.id.desc())
|
||||
.limit(1)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
|
||||
def run_collector_now(collector_name: str) -> bool:
|
||||
"""Run a collector immediately (not scheduled)."""
|
||||
collector = collector_registry.get(collector_name)
|
||||
|
||||
Reference in New Issue
Block a user