from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Query, status, Response
from fastapi.responses import StreamingResponse
from sqlalchemy import select, func, text
from sqlalchemy.ext.asyncio import AsyncSession
import json
import csv
import io
from app.core.collected_data_fields import get_metadata_field
from app.core.countries import COUNTRY_OPTIONS, get_country_search_variants, normalize_country
from app.db.session import get_db
from app.models.user import User
from app.core.security import get_current_user
from app.models.collected_data import CollectedData

router = APIRouter()

# SQL expression that pulls the country out of the JSONB ``metadata`` column.
COUNTRY_SQL = "metadata->>'country'"

# Column expressions matched by the free-text search filter (ILIKE).
SEARCHABLE_SQL = [
    "name",
    "title",
    "description",
    "source",
    "data_type",
    "source_id",
    "metadata::text",
]


def parse_multi_values(value: Optional[str]) -> list[str]:
    """Split a comma-separated query value into stripped, non-empty items.

    Returns an empty list for ``None`` or an empty string.
    """
    if not value:
        return []
    return [item.strip() for item in value.split(",") if item.strip()]


def build_in_condition(field_sql: str, values: list[str], param_prefix: str, params: dict) -> str:
    """Build a parameterized ``<field> IN (...)`` SQL condition.

    One bind parameter named ``{param_prefix}_{i}`` is added to *params*
    for each value, so the values are never interpolated into the SQL text.
    """
    placeholders = []
    for index, value in enumerate(values):
        key = f"{param_prefix}_{index}"
        params[key] = value
        placeholders.append(f":{key}")
    return f"{field_sql} IN ({', '.join(placeholders)})"


def build_search_condition(search: Optional[str], params: dict) -> Optional[str]:
    """Build an OR-joined ILIKE condition over SEARCHABLE_SQL for *search*.

    Country aliases (from ``get_country_search_variants``) are expanded into
    additional search terms. As a side effect this also registers the bind
    parameters consumed by ``build_search_rank_sql``: ``:search_exact``,
    ``:search_prefix``, ``:country_search_exact`` and
    ``:country_search_prefix``.

    Returns ``None`` when *search* is missing or blank, otherwise the SQL
    condition string (parenthesized).
    """
    if not search:
        return None
    normalized = search.strip()
    if not normalized:
        return None

    # Compute the country-name variants once; they are needed both for term
    # expansion here and for the canonical ranking parameters below.
    country_variants = get_country_search_variants(normalized)

    # Case-insensitive de-duplication via a running "seen" set instead of
    # rebuilding the folded set on every iteration.
    search_terms = [normalized]
    seen = {normalized.casefold()}
    for variant in country_variants:
        folded = variant.casefold()
        if folded not in seen:
            seen.add(folded)
            search_terms.append(variant)

    conditions = []
    for index, term in enumerate(search_terms):
        params[f"search_{index}"] = f"%{term}%"
        conditions.extend(f"{field} ILIKE :search_{index}" for field in SEARCHABLE_SQL)

    # Parameters used only by the ranking CASE expression.
    params["search_exact"] = normalized
    params["search_prefix"] = f"{normalized}%"
    canonical = country_variants[0] if country_variants else None
    params["country_search_exact"] = canonical or normalized
    params["country_search_prefix"] = f"{(canonical or normalized)}%"
    return "(" + " OR ".join(conditions) + ")"


def build_search_rank_sql(search: Optional[str]) -> str:
    """Return a SQL CASE expression scoring how well a row matches *search*.

    The expression references bind parameters that are registered by
    ``build_search_condition``; the two functions must be used together.
    Returns the constant ``"0"`` when there is no usable search term.
    """
    if not search or not search.strip():
        return "0"
    return """
        CASE
            WHEN name ILIKE :search_exact THEN 700
            WHEN name ILIKE :search_prefix THEN 600
            WHEN title ILIKE :search_exact THEN 500
            WHEN title ILIKE :search_prefix THEN 400
            WHEN metadata->>'country' ILIKE :country_search_exact THEN 380
            WHEN metadata->>'country' ILIKE :country_search_prefix THEN 340
            WHEN source_id ILIKE :search_exact THEN 350
            WHEN source ILIKE :search_exact THEN 300
            WHEN data_type ILIKE :search_exact THEN 250
            WHEN description ILIKE :search_0 THEN 150
            WHEN metadata::text ILIKE :search_0 THEN 100
            WHEN title ILIKE :search_0 THEN 80
            WHEN name ILIKE :search_0 THEN 60
            WHEN source ILIKE :search_0 THEN 40
            WHEN data_type ILIKE :search_0 THEN 30
            WHEN source_id ILIKE :search_0 THEN 20
            ELSE 0
        END
    """


def serialize_collected_row(row) -> dict:
    """Convert a collected_data row tuple into the API response dict.

    Expects the column order: id, source, source_id, data_type, name, title,
    description, metadata, collected_at, reference_date, is_valid
    (indices 0-10). Well-known metadata keys are promoted to top-level
    fields via ``get_metadata_field``.
    """
    metadata = row[7]
    return {
        "id": row[0],
        "source": row[1],
        "source_id": row[2],
        "data_type": row[3],
        "name": row[4],
        "title": row[5],
        "description": row[6],
        "country": get_metadata_field(metadata, "country"),
        "city": get_metadata_field(metadata, "city"),
        "latitude": get_metadata_field(metadata, "latitude"),
        "longitude": get_metadata_field(metadata, "longitude"),
        "value": get_metadata_field(metadata, "value"),
        "unit": get_metadata_field(metadata, "unit"),
        "metadata": metadata,
        "cores": get_metadata_field(metadata, "cores"),
        "rmax": get_metadata_field(metadata, "rmax"),
        "rpeak": get_metadata_field(metadata, "rpeak"),
        "power": get_metadata_field(metadata, "power"),
        "collected_at": row[8].isoformat() if row[8] else None,
        "reference_date": row[9].isoformat() if row[9] else None,
        "is_valid": row[10],
    }


@router.get("")
async def list_collected_data(
    mode: str = Query("current", description="查询模式: current/history"),
    source: Optional[str] = Query(None, description="数据源过滤"),
    data_type: Optional[str] = Query(None, description="数据类型过滤"),
    country: Optional[str] = Query(None, description="国家过滤"),
    search: Optional[str] = Query(None, description="搜索名称"),
    page: int = Query(1, ge=1, description="页码"),
    page_size: int = Query(20, ge=1, le=100, description="每页数量"),
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """List collected data with filtering, search ranking and pagination."""
    normalized_country = normalize_country(country) if country else None
    source_values = parse_multi_values(source)
    data_type_values = parse_multi_values(data_type)

    # Build WHERE clause (all user input goes through bind parameters).
    conditions = []
    params: dict = {}

    if mode != "history":
        # "current" mode only shows rows flagged as the latest snapshot.
        conditions.append("COALESCE(is_current, TRUE) = TRUE")
    if source_values:
        conditions.append(build_in_condition("source", source_values, "source", params))
    if data_type_values:
        conditions.append(build_in_condition("data_type", data_type_values, "data_type", params))
    if normalized_country:
        conditions.append(f"{COUNTRY_SQL} = :country")
        params["country"] = normalized_country

    search_condition = build_search_condition(search, params)
    if search_condition:
        conditions.append(search_condition)

    where_sql = " AND ".join(conditions) if conditions else "1=1"
    search_rank_sql = build_search_rank_sql(search)

    # Calculate offset
    offset = (page - 1) * page_size

    # Total count first, reusing the same WHERE clause and bind params.
    count_query = text(f"SELECT COUNT(*) FROM collected_data WHERE {where_sql}")
    count_result = await db.execute(count_query, params)
    total = count_result.scalar()

    # One page of rows: best search matches first, newest first within ties.
    query = text(f"""
        SELECT id, source, source_id, data_type, name, title, description,
               metadata, collected_at, reference_date, is_valid,
               {search_rank_sql} AS search_rank
        FROM collected_data
        WHERE {where_sql}
        ORDER BY search_rank DESC, collected_at DESC
        LIMIT :limit OFFSET :offset
    """)
    params["limit"] = page_size
    params["offset"] = offset
    result = await db.execute(query, params)
    rows = result.fetchall()

    # Drop the trailing search_rank column before serializing.
    data = [serialize_collected_row(row[:11]) for row in rows]

    return {
        "total": total,
        "page": page,
        "page_size": page_size,
        "data": data,
    }
@router.get("/summary")
async def get_data_summary(
    mode: str = Query("current", description="查询模式: current/history"),
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Aggregate record counts by source and by (source, data_type)."""
    # "current" mode only counts rows flagged as the latest snapshot.
    where_sql = "WHERE COALESCE(is_current, TRUE) = TRUE" if mode != "history" else ""

    # Counts broken down by source and data_type.
    result = await db.execute(
        text("""
            SELECT source, data_type, COUNT(*) as count
            FROM collected_data
        """ + where_sql + """
            GROUP BY source, data_type
            ORDER BY source, data_type
        """)
    )
    rows = result.fetchall()

    by_source: dict = {}
    total = 0
    for source, data_type, count in rows:
        by_source.setdefault(source, {})[data_type] = count
        total += count

    # Per-source totals, largest first.
    source_totals = await db.execute(
        text("""
            SELECT source, COUNT(*) as count
            FROM collected_data
        """ + where_sql + """
            GROUP BY source
            ORDER BY count DESC
        """)
    )
    source_rows = source_totals.fetchall()

    return {
        "total_records": total,
        "by_source": by_source,
        "source_totals": [{"source": row[0], "count": row[1]} for row in source_rows],
    }


@router.get("/sources")
async def get_data_sources(
    mode: str = Query("current", description="查询模式: current/history"),
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """List all distinct data sources."""
    result = await db.execute(
        text("""
            SELECT DISTINCT source FROM collected_data
        """ + ("WHERE COALESCE(is_current, TRUE) = TRUE " if mode != "history" else "") + """
            ORDER BY source
        """)
    )
    rows = result.fetchall()
    return {
        "sources": [row[0] for row in rows],
    }


@router.get("/types")
async def get_data_types(
    mode: str = Query("current", description="查询模式: current/history"),
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """List all distinct data types."""
    result = await db.execute(
        text("""
            SELECT DISTINCT data_type FROM collected_data
        """ + ("WHERE COALESCE(is_current, TRUE) = TRUE " if mode != "history" else "") + """
            ORDER BY data_type
        """)
    )
    rows = result.fetchall()
    return {
        "data_types": [row[0] for row in rows],
    }


@router.get("/countries")
async def get_countries(
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return the static list of country options.

    NOTE(review): the ``db`` dependency is unused here; kept for signature
    stability.
    """
    return {
        "countries": COUNTRY_OPTIONS,
    }


@router.get("/{data_id}")
async def get_collected_data(
    data_id: int,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Fetch a single collected-data record by id; 404 if it does not exist."""
    result = await db.execute(
        text("""
            SELECT id, source, source_id, data_type, name, title, description,
                   metadata, collected_at, reference_date, is_valid
            FROM collected_data
            WHERE id = :id
        """),
        {"id": data_id},
    )
    row = result.fetchone()
    if not row:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="数据不存在",
        )
    return serialize_collected_row(row)


def build_where_clause(
    source: Optional[str], data_type: Optional[str], country: Optional[str], search: Optional[str]
):
    """Build the shared WHERE clause and bind params for export queries.

    Returns ``(where_sql, params)``; ``where_sql`` is ``"1=1"`` when no
    filter applies. All user input is passed via bind parameters.
    """
    conditions = []
    params: dict = {}

    source_values = parse_multi_values(source)
    data_type_values = parse_multi_values(data_type)

    if source_values:
        conditions.append(build_in_condition("source", source_values, "source", params))
    if data_type_values:
        conditions.append(build_in_condition("data_type", data_type_values, "data_type", params))

    normalized_country = normalize_country(country) if country else None
    if normalized_country:
        conditions.append(f"{COUNTRY_SQL} = :country")
        params["country"] = normalized_country

    search_condition = build_search_condition(search, params)
    if search_condition:
        conditions.append(search_condition)

    where_sql = " AND ".join(conditions) if conditions else "1=1"
    return where_sql, params


async def _fetch_export_rows(
    db: AsyncSession,
    mode: str,
    source: Optional[str],
    data_type: Optional[str],
    country: Optional[str],
    search: Optional[str],
    limit: int,
):
    """Run the shared export query and return raw rows, newest first.

    Used by both the JSON and CSV export endpoints, which previously
    duplicated this logic verbatim.
    """
    where_sql, params = build_where_clause(source, data_type, country, search)
    if mode != "history":
        where_sql = f"({where_sql}) AND COALESCE(is_current, TRUE) = TRUE"
    params["limit"] = limit

    query = text(f"""
        SELECT id, source, source_id, data_type, name, title, description,
               metadata, collected_at, reference_date, is_valid
        FROM collected_data
        WHERE {where_sql}
        ORDER BY collected_at DESC
        LIMIT :limit
    """)
    result = await db.execute(query, params)
    return result.fetchall()


@router.get("/export/json")
async def export_json(
    mode: str = Query("current", description="查询模式: current/history"),
    source: Optional[str] = Query(None, description="数据源过滤"),
    data_type: Optional[str] = Query(None, description="数据类型过滤"),
    country: Optional[str] = Query(None, description="国家过滤"),
    search: Optional[str] = Query(None, description="搜索名称"),
    limit: int = Query(10000, ge=1, le=50000, description="最大导出数量"),
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Export filtered collected data as a JSON attachment."""
    rows = await _fetch_export_rows(db, mode, source, data_type, country, search, limit)

    data = [serialize_collected_row(row) for row in rows]
    json_str = json.dumps({"data": data, "total": len(data)}, ensure_ascii=False, indent=2)

    return StreamingResponse(
        io.StringIO(json_str),
        media_type="application/json",
        headers={
            "Content-Disposition": f"attachment; filename=collected_data_{source or 'all'}.json"
        },
    )


@router.get("/export/csv")
async def export_csv(
    mode: str = Query("current", description="查询模式: current/history"),
    source: Optional[str] = Query(None, description="数据源过滤"),
    data_type: Optional[str] = Query(None, description="数据类型过滤"),
    country: Optional[str] = Query(None, description="国家过滤"),
    search: Optional[str] = Query(None, description="搜索名称"),
    limit: int = Query(10000, ge=1, le=50000, description="最大导出数量"),
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Export filtered collected data as a CSV attachment."""
    rows = await _fetch_export_rows(db, mode, source, data_type, country, search, limit)

    output = io.StringIO()
    writer = csv.writer(output)

    # Write header
    writer.writerow(
        [
            "ID",
            "Source",
            "Source ID",
            "Type",
            "Name",
            "Title",
            "Description",
            "Country",
            "City",
            "Latitude",
            "Longitude",
            "Value",
            "Unit",
            "Metadata",
            "Collected At",
            "Reference Date",
            "Is Valid",
        ]
    )

    # Write data
    for row in rows:
        metadata = row[7]
        writer.writerow(
            [
                row[0],
                row[1],
                row[2],
                row[3],
                row[4],
                row[5],
                row[6],
                get_metadata_field(metadata, "country"),
                get_metadata_field(metadata, "city"),
                get_metadata_field(metadata, "latitude"),
                get_metadata_field(metadata, "longitude"),
                get_metadata_field(metadata, "value"),
                get_metadata_field(metadata, "unit"),
                json.dumps(metadata) if metadata else "",
                row[8].isoformat() if row[8] else "",
                row[9].isoformat() if row[9] else "",
                row[10],
            ]
        )

    # Rewind and stream the existing buffer instead of copying its contents
    # into a second StringIO (the export can be up to 50k rows).
    output.seek(0)
    return StreamingResponse(
        output,
        media_type="text/csv",
        headers={
            "Content-Disposition": f"attachment; filename=collected_data_{source or 'all'}.csv"
        },
    )