# Files
# planet/backend/app/api/v1/collected_data.py
# 2026-03-25 17:19:10 +08:00
#
# 498 lines
# 16 KiB
# Python

from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Query, status, Response
from fastapi.responses import StreamingResponse
from sqlalchemy import select, func, text
from sqlalchemy.ext.asyncio import AsyncSession
import json
import csv
import io
from app.core.collected_data_fields import get_metadata_field
from app.core.countries import COUNTRY_OPTIONS, get_country_search_variants, normalize_country
from app.db.session import get_db
from app.models.user import User
from app.core.security import get_current_user
from app.models.collected_data import CollectedData
# Router that the collected-data endpoints below are registered on.
router = APIRouter()

# SQL expression extracting the ``country`` key from the ``metadata`` JSON column.
COUNTRY_SQL = "metadata->>'country'"

# Column expressions that free-text search matches against with ILIKE.
SEARCHABLE_SQL = [
    "name", "title", "description",
    "source", "data_type", "source_id",
    "metadata::text",
]
def parse_multi_values(value: Optional[str]) -> list[str]:
    """Split a comma-separated query value into trimmed, non-empty items.

    ``None`` and ``""`` both yield ``[]``; blank segments (e.g. the middle of
    ``"a,,b"``) are dropped.
    """
    items: list[str] = []
    for raw in (value or "").split(","):
        cleaned = raw.strip()
        if cleaned:
            items.append(cleaned)
    return items
def build_in_condition(field_sql: str, values: list[str], param_prefix: str, params: dict) -> str:
    """Build a parameterised ``<field_sql> IN (...)`` SQL predicate.

    Every value is bound as ``:<param_prefix>_<index>`` and written into
    ``params`` (mutated in place), so values are never interpolated into SQL.

    Args:
        field_sql: SQL expression to test. It is interpolated verbatim, so it
            must come from trusted code, never from user input.
        values: Values the expression must equal one of.
        param_prefix: Prefix for the bind-parameter names; must be unique
            within the statement.
        params: Bind-parameter mapping to extend.

    Returns:
        The SQL fragment. An empty ``values`` returns ``"1=0"`` (matches no
        rows), because ``IN ()`` is not valid SQL.
    """
    if not values:
        return "1=0"
    placeholders = []
    for index, value in enumerate(values):
        key = f"{param_prefix}_{index}"
        params[key] = value
        placeholders.append(f":{key}")
    return f"{field_sql} IN ({', '.join(placeholders)})"
def build_search_condition(search: Optional[str], params: dict) -> Optional[str]:
    """Build the OR-ed ILIKE search predicate for ``collected_data``.

    The stripped search string, plus every country-name variant returned by
    ``get_country_search_variants`` (deduplicated case-insensitively), is
    matched as ``ILIKE '%term%'`` against each column in ``SEARCHABLE_SQL``.

    LIKE metacharacters in the input (``%``, ``_`` and the backslash escape
    character) are escaped, so user text matches literally instead of acting
    as a wildcard.

    Side effect: these bind parameters are written into ``params``:

    * ``search_<i>``: ``%term%`` pattern for the i-th term (``search_0`` is
      always the user's own input);
    * ``search_exact`` / ``search_prefix``: referenced by the ranking SQL;
    * ``country_search_exact`` / ``country_search_prefix``: built from the
      first country variant (or the raw input when there is none), also
      referenced by the ranking SQL.

    Returns:
        The parenthesised SQL fragment, or ``None`` (leaving ``params``
        untouched) when ``search`` is empty or whitespace-only.
    """

    def _escape_like(value: str) -> str:
        # Backslash is PostgreSQL's default LIKE/ILIKE escape character;
        # escape it first so the escapes added below are not doubled.
        return value.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")

    if not search:
        return None
    normalized = search.strip()
    if not normalized:
        return None

    # Looked up once: the variants feed both the term list and the ranking params.
    country_variants = get_country_search_variants(normalized)

    search_terms = [normalized]
    seen = {normalized.casefold()}
    for variant in country_variants:
        folded = variant.casefold()
        if folded not in seen:
            seen.add(folded)
            search_terms.append(variant)

    conditions = []
    for index, term in enumerate(search_terms):
        params[f"search_{index}"] = f"%{_escape_like(term)}%"
        conditions.extend(f"{field} ILIKE :search_{index}" for field in SEARCHABLE_SQL)

    escaped = _escape_like(normalized)
    params["search_exact"] = escaped
    params["search_prefix"] = f"{escaped}%"

    # Falls back to the raw input when there is no (or an empty) first variant.
    country_term = _escape_like((country_variants[0] if country_variants else "") or normalized)
    params["country_search_exact"] = country_term
    params["country_search_prefix"] = f"{country_term}%"
    return "(" + " OR ".join(conditions) + ")"
def build_search_rank_sql(search: Optional[str]) -> str:
    """Return a SQL expression scoring how well a row matches ``search``.

    For an empty or whitespace-only search the constant ``"0"`` is returned.
    Otherwise a ``CASE`` expression is returned. It references the bind
    parameters ``search_exact``, ``search_prefix``, ``country_search_exact``,
    ``country_search_prefix`` and ``search_0``, so the same search must also
    have been passed through ``build_search_condition`` with the same params.

    Exact/prefix matches on ``name`` and ``title`` score highest. Next come
    country and exact identifier matches, then substring matches.
    """
    if not search or not search.strip():
        return "0"
    # (predicate, score). CASE yields the FIRST matching WHEN, so the clauses
    # are emitted in descending score order: a row always receives the highest
    # score it qualifies for, regardless of how this list is arranged.
    rules = [
        ("name ILIKE :search_exact", 700),
        ("name ILIKE :search_prefix", 600),
        ("title ILIKE :search_exact", 500),
        ("title ILIKE :search_prefix", 400),
        ("metadata->>'country' ILIKE :country_search_exact", 380),
        ("source_id ILIKE :search_exact", 350),
        ("metadata->>'country' ILIKE :country_search_prefix", 340),
        ("source ILIKE :search_exact", 300),
        ("data_type ILIKE :search_exact", 250),
        ("description ILIKE :search_0", 150),
        ("metadata::text ILIKE :search_0", 100),
        ("title ILIKE :search_0", 80),
        ("name ILIKE :search_0", 60),
        ("source ILIKE :search_0", 40),
        ("data_type ILIKE :search_0", 30),
        ("source_id ILIKE :search_0", 20),
    ]
    rules.sort(key=lambda rule: rule[1], reverse=True)
    whens = "\n".join(f"    WHEN {predicate} THEN {score}" for predicate, score in rules)
    return f"\nCASE\n{whens}\n    ELSE 0\nEND\n"
def serialize_collected_row(row) -> dict:
    """Convert a positional ``collected_data`` row into the API's response dict.

    ``row`` is read by position and must begin with the columns ``id, source,
    source_id, data_type, name, title, description, metadata, collected_at,
    reference_date, is_valid``, in the order the queries in this module select
    them. Any trailing columns are ignored.

    Several keys (country, city, latitude, longitude, value, unit, cores, rmax,
    rpeak, power) are lifted out of ``metadata`` via ``get_metadata_field``.
    The full ``metadata`` is also returned as-is. Timestamps are rendered with
    ``isoformat()``, and falsy timestamps become ``None``.
    """
    metadata = row[7]

    # Plain columns, copied straight from the first seven positions.
    record = dict(
        zip(("id", "source", "source_id", "data_type", "name", "title", "description"), row)
    )
    for key in ("country", "city", "latitude", "longitude", "value", "unit"):
        record[key] = get_metadata_field(metadata, key)
    record["metadata"] = metadata
    for key in ("cores", "rmax", "rpeak", "power"):
        record[key] = get_metadata_field(metadata, key)
    for key, stamp in (("collected_at", row[8]), ("reference_date", row[9])):
        record[key] = stamp.isoformat() if stamp else None
    record["is_valid"] = row[10]
    return record
@router.get("")
async def list_collected_data(
    mode: str = Query("current", description="查询模式: current/history"),
    source: Optional[str] = Query(None, description="数据源过滤"),
    data_type: Optional[str] = Query(None, description="数据类型过滤"),
    country: Optional[str] = Query(None, description="国家过滤"),
    search: Optional[str] = Query(None, description="搜索名称"),
    page: int = Query(1, ge=1, description="页码"),
    page_size: int = Query(20, ge=1, le=100, description="每页数量"),
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """List collected data with filtering, search ranking and pagination.

    Filters come from ``build_where_clause``, the same helper the export
    endpoints use, so list and export apply identical filtering:
    ``source``/``data_type`` accept comma-separated lists, ``country`` is
    normalized, and ``search`` is an ILIKE match. Unless ``mode`` is
    ``"history"``, only rows whose ``is_current`` is true (NULL counts as
    true) are included.

    Rows are ordered by search relevance, then by newest ``collected_at``.

    Returns:
        ``{"total", "page", "page_size", "data"}``. ``total`` counts all
        matching rows; ``data`` holds the requested page of serialized rows.
    """
    where_sql, params = build_where_clause(source, data_type, country, search)
    if mode != "history":
        where_sql = f"COALESCE(is_current, TRUE) = TRUE AND ({where_sql})"
    search_rank_sql = build_search_rank_sql(search)

    count_result = await db.execute(
        text(f"SELECT COUNT(*) FROM collected_data WHERE {where_sql}"), params
    )
    total = count_result.scalar()

    query = text(f"""
        SELECT id, source, source_id, data_type, name, title, description,
               metadata, collected_at, reference_date, is_valid,
               {search_rank_sql} AS search_rank
        FROM collected_data
        WHERE {where_sql}
        ORDER BY search_rank DESC, collected_at DESC
        LIMIT :limit OFFSET :offset
    """)
    page_params = {**params, "limit": page_size, "offset": (page - 1) * page_size}
    result = await db.execute(query, page_params)
    # serialize_collected_row only reads the first 11 columns; search_rank is ignored.
    data = [serialize_collected_row(row) for row in result.fetchall()]

    return {
        "total": total,
        "page": page,
        "page_size": page_size,
        "data": data,
    }
@router.get("/summary")
async def get_data_summary(
    mode: str = Query("current", description="查询模式: current/history"),
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return record counts of ``collected_data`` grouped by source and data type.

    Unless ``mode`` is ``"history"``, only rows whose ``is_current`` is true
    (NULL counts as true) are counted.

    Returns:
        ``total_records``: overall count.
        ``by_source``: ``{source: {data_type: count}}``.
        ``source_totals``: ``[{"source", "count"}]`` sorted by count, descending.
    """
    where_sql = "WHERE COALESCE(is_current, TRUE) = TRUE" if mode != "history" else ""
    result = await db.execute(
        text(f"""
            SELECT source, data_type, COUNT(*) as count
            FROM collected_data
            {where_sql}
            GROUP BY source, data_type
            ORDER BY source, data_type
        """)
    )

    # Per-source totals are derived from this single result set, not from a
    # second GROUP BY query. This saves a round trip and guarantees the
    # totals add up to total_records even if rows change between statements.
    by_source: dict = {}
    source_counts: dict = {}
    total = 0
    for source, data_type, count in result.fetchall():
        by_source.setdefault(source, {})[data_type] = count
        source_counts[source] = source_counts.get(source, 0) + count
        total += count

    ranked_sources = sorted(source_counts.items(), key=lambda item: item[1], reverse=True)
    return {
        "total_records": total,
        "by_source": by_source,
        "source_totals": [{"source": source, "count": count} for source, count in ranked_sources],
    }
@router.get("/sources")
async def get_data_sources(
    mode: str = Query("current", description="查询模式: current/history"),
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return every distinct ``source`` in ``collected_data``, sorted ascending.

    Outside ``"history"`` mode, only rows whose ``is_current`` is true (or
    NULL) are considered.
    """
    current_filter = "" if mode == "history" else "WHERE COALESCE(is_current, TRUE) = TRUE "
    statement = text(f"""
        SELECT DISTINCT source FROM collected_data
        {current_filter}
        ORDER BY source
    """)
    fetched = (await db.execute(statement)).fetchall()
    return {"sources": [name for (name,) in fetched]}
@router.get("/types")
async def get_data_types(
    mode: str = Query("current", description="查询模式: current/history"),
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return every distinct ``data_type`` in ``collected_data``, sorted ascending.

    Outside ``"history"`` mode, only rows whose ``is_current`` is true (or
    NULL) are considered.
    """
    only_current = mode != "history"
    filter_sql = "WHERE COALESCE(is_current, TRUE) = TRUE " if only_current else ""
    outcome = await db.execute(
        text(f"""
        SELECT DISTINCT data_type FROM collected_data
        {filter_sql}
        ORDER BY data_type
        """)
    )
    type_names = [entry for (entry,) in outcome.fetchall()]
    return {"data_types": type_names}
@router.get("/countries")
async def get_countries(
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return the list of selectable countries.

    The list is the static ``COUNTRY_OPTIONS`` constant; the database is not
    queried, and ``db`` is unused. ``current_user`` is kept as a dependency,
    presumably to require an authenticated caller; confirm in get_current_user.
    """
    payload = {"countries": COUNTRY_OPTIONS}
    return payload
@router.get("/{data_id}")
async def get_collected_data(
    data_id: int,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return a single ``collected_data`` row, looked up by ``id``, in serialized form.

    Raises:
        HTTPException: 404 with detail "数据不存在" ("data does not exist")
            when no row has the given id.
    """
    statement = text("""
        SELECT id, source, source_id, data_type, name, title, description,
               metadata, collected_at, reference_date, is_valid
        FROM collected_data
        WHERE id = :id
    """)
    record = (await db.execute(statement, {"id": data_id})).fetchone()
    if record is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="数据不存在")
    return serialize_collected_row(record)
def build_where_clause(
    source: Optional[str], data_type: Optional[str], country: Optional[str], search: Optional[str]
):
    """Build the filter part of a ``WHERE`` clause plus its bind parameters.

    ``source`` and ``data_type`` are comma-separated lists turned into
    ``IN (...)`` predicates. ``country`` is normalized and compared with the
    metadata country. ``search`` becomes the ILIKE search predicate. The
    ``is_current`` filter is NOT applied here; callers add it themselves.

    Returns:
        ``(where_sql, params)``. ``where_sql`` is the conditions joined with
        ``AND`` (``"1=1"`` when there are none); ``params`` holds the bind
        values they reference.
    """
    params: dict = {}
    clauses: list[str] = []

    for column, raw in (("source", source), ("data_type", data_type)):
        wanted = parse_multi_values(raw)
        if wanted:
            clauses.append(build_in_condition(column, wanted, column, params))

    country_value = normalize_country(country) if country else None
    if country_value:
        clauses.append(f"{COUNTRY_SQL} = :country")
        params["country"] = country_value

    search_clause = build_search_condition(search, params)
    if search_clause:
        clauses.append(search_clause)

    return (" AND ".join(clauses) or "1=1"), params
@router.get("/export/json")
async def export_json(
    mode: str = Query("current", description="查询模式: current/history"),
    source: Optional[str] = Query(None, description="数据源过滤"),
    data_type: Optional[str] = Query(None, description="数据类型过滤"),
    country: Optional[str] = Query(None, description="国家过滤"),
    search: Optional[str] = Query(None, description="搜索名称"),
    limit: int = Query(10000, ge=1, le=50000, description="最大导出数量"),
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Export filtered collected data as a downloadable JSON file.

    Filters match ``build_where_clause``. Unless ``mode`` is ``"history"``,
    only rows whose ``is_current`` is true (or NULL) are exported. At most
    ``limit`` rows are exported, newest ``collected_at`` first.

    The body is ``{"data": [...], "total": <number of exported rows>}``,
    pretty-printed, with non-ASCII characters kept as-is.
    """
    where_sql, params = build_where_clause(source, data_type, country, search)
    if mode != "history":
        where_sql = f"({where_sql}) AND COALESCE(is_current, TRUE) = TRUE"
    params["limit"] = limit
    query = text(f"""
        SELECT id, source, source_id, data_type, name, title, description,
               metadata, collected_at, reference_date, is_valid
        FROM collected_data
        WHERE {where_sql}
        ORDER BY collected_at DESC
        LIMIT :limit
    """)
    result = await db.execute(query, params)
    data = [serialize_collected_row(row) for row in result.fetchall()]
    json_str = json.dumps({"data": data, "total": len(data)}, ensure_ascii=False, indent=2)

    # ``source`` is raw user input. Restrict it to a header-safe ASCII subset:
    # header values are encoded as Latin-1, and quotes/semicolons would corrupt
    # the Content-Disposition parameters.
    safe_source = "".join(
        ch if ch.isascii() and (ch.isalnum() or ch in "-_.") else "_"
        for ch in (source or "all")
    )
    # The payload is already fully built in memory, so send it in one piece
    # instead of streaming an io.StringIO line by line.
    return Response(
        content=json_str,
        media_type="application/json",
        headers={
            "Content-Disposition": f'attachment; filename="collected_data_{safe_source}.json"'
        },
    )
@router.get("/export/csv")
async def export_csv(
    mode: str = Query("current", description="查询模式: current/history"),
    source: Optional[str] = Query(None, description="数据源过滤"),
    data_type: Optional[str] = Query(None, description="数据类型过滤"),
    country: Optional[str] = Query(None, description="国家过滤"),
    search: Optional[str] = Query(None, description="搜索名称"),
    limit: int = Query(10000, ge=1, le=50000, description="最大导出数量"),
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Export filtered collected data as a downloadable CSV file.

    Filters match ``build_where_clause``. Unless ``mode`` is ``"history"``,
    only rows whose ``is_current`` is true (or NULL) are exported. At most
    ``limit`` rows are exported, newest ``collected_at`` first.

    Cell values come from ``serialize_collected_row``, so they match the JSON
    export. The ``Metadata`` column holds the JSON-encoded metadata (empty when
    there is none), and missing values are written as empty cells.
    """
    where_sql, params = build_where_clause(source, data_type, country, search)
    if mode != "history":
        where_sql = f"({where_sql}) AND COALESCE(is_current, TRUE) = TRUE"
    params["limit"] = limit
    query = text(f"""
        SELECT id, source, source_id, data_type, name, title, description,
               metadata, collected_at, reference_date, is_valid
        FROM collected_data
        WHERE {where_sql}
        ORDER BY collected_at DESC
        LIMIT :limit
    """)
    result = await db.execute(query, params)
    rows = result.fetchall()

    # (CSV header, key in serialize_collected_row's output), in column order.
    columns = (
        ("ID", "id"),
        ("Source", "source"),
        ("Source ID", "source_id"),
        ("Type", "data_type"),
        ("Name", "name"),
        ("Title", "title"),
        ("Description", "description"),
        ("Country", "country"),
        ("City", "city"),
        ("Latitude", "latitude"),
        ("Longitude", "longitude"),
        ("Value", "value"),
        ("Unit", "unit"),
        ("Metadata", "metadata"),
        ("Collected At", "collected_at"),
        ("Reference Date", "reference_date"),
        ("Is Valid", "is_valid"),
    )

    output = io.StringIO()
    writer = csv.writer(output)
    writer.writerow([header for header, _ in columns])
    for row in rows:
        record = serialize_collected_row(row)
        metadata = record["metadata"]
        # Keep non-ASCII text readable, consistent with the JSON export.
        record["metadata"] = json.dumps(metadata, ensure_ascii=False) if metadata else ""
        # csv.writer renders None (e.g. a missing timestamp) as an empty cell.
        writer.writerow([record[key] for _, key in columns])

    # ``source`` is raw user input. Restrict it to a header-safe ASCII subset:
    # header values are encoded as Latin-1, and quotes/semicolons would corrupt
    # the Content-Disposition parameters.
    safe_source = "".join(
        ch if ch.isascii() and (ch.isalnum() or ch in "-_.") else "_"
        for ch in (source or "all")
    )
    # The CSV is already fully built in memory, so send it in one piece
    # instead of streaming an io.StringIO line by line.
    return Response(
        content=output.getvalue(),
        media_type="text/csv",
        headers={
            "Content-Disposition": f'attachment; filename="collected_data_{safe_source}.csv"'
        },
    )