# Files
# planet/backend/app/services/collectors/base.py
# 2026-03-25 17:19:10 +08:00
#
# 414 lines
# 15 KiB
# Python
#
"""Base collector class for all data sources"""
from abc import ABC, abstractmethod
from typing import Dict, List, Any, Optional
from datetime import datetime
import httpx
from sqlalchemy import select, text
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.collected_data_fields import build_dynamic_metadata, get_record_field
from app.core.config import settings
from app.core.countries import normalize_country
class BaseCollector(ABC):
    """Abstract base class for data collectors.

    Subclasses implement ``fetch`` (and optionally ``transform``). ``run``
    drives the fetch -> transform -> save pipeline, records a
    ``CollectionTask`` and a ``DataSnapshot`` for the run, and stores the
    results as versioned ``CollectedData`` rows.
    """

    name: str = "base_collector"
    priority: str = "P1"
    module: str = "L1"
    frequency_hours: int = 4
    data_type: str = "generic"

    def __init__(self):
        self._current_task = None  # CollectionTask of the active run(), if any
        self._db_session = None  # session passed to the active run()
        self._datasource_id = 1  # default datasource id used by run()/_create_snapshot()
        self._resolved_url: Optional[str] = None  # filled in by resolve_url()

    async def resolve_url(self, db: AsyncSession) -> None:
        """Look up this collector's URL via the data-sources config and cache it on ``_resolved_url``."""
        from app.core.data_sources import get_data_sources_config

        config = get_data_sources_config()
        self._resolved_url = await config.get_url(self.name, db)

    def update_progress(self, records_processed: int):
        """Update task progress - call this during data processing.

        Sets ``records_processed`` and the percentage ``progress`` on the
        current task. Does nothing when there is no active task/session or
        the task's ``total_records`` is unset or zero. Nothing is committed
        here.
        """
        task = self._current_task
        if not task or not self._db_session:
            return
        # total_records may still be None if called before fetch() returned.
        total = task.total_records or 0
        if total <= 0:
            return
        task.records_processed = records_processed
        task.progress = (records_processed / total) * 100

    async def set_phase(self, phase: str):
        """Record the current pipeline phase on the active task and commit it."""
        if self._current_task and self._db_session:
            self._current_task.phase = phase
            await self._db_session.commit()

    @abstractmethod
    async def fetch(self) -> List[Dict[str, Any]]:
        """Fetch raw data from source"""
        pass

    def transform(self, raw_data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Transform raw data to internal format (default: pass through)"""
        return raw_data

    def _parse_reference_date(self, value: Any) -> Optional[datetime]:
        """Coerce ``value`` to a ``datetime``.

        Falsy values and unsupported types give ``None``. ``datetime``
        instances are returned unchanged. Strings are parsed with
        ``datetime.fromisoformat`` after replacing ``Z`` with ``+00:00``.

        Raises:
            ValueError: if a string is not in a format ``fromisoformat`` accepts.
        """
        # NOTE(review): strings with an offset (or "Z") yield timezone-aware
        # datetimes, while datetime inputs pass through unchanged and may be
        # naive; mixing both makes max() in _create_snapshot raise TypeError.
        if not value:
            return None
        if isinstance(value, datetime):
            return value
        if isinstance(value, str):
            return datetime.fromisoformat(value.replace("Z", "+00:00"))
        return None

    def _build_comparable_payload(self, record: Any) -> Dict[str, Any]:
        """Project a ``CollectedData``-like record onto the fields used for change detection."""
        reference_date = getattr(record, "reference_date", None)
        return {
            "name": getattr(record, "name", None),
            "title": getattr(record, "title", None),
            "description": getattr(record, "description", None),
            "country": get_record_field(record, "country"),
            "city": get_record_field(record, "city"),
            "latitude": get_record_field(record, "latitude"),
            "longitude": get_record_field(record, "longitude"),
            "value": get_record_field(record, "value"),
            "unit": get_record_field(record, "unit"),
            "metadata": getattr(record, "extra_data", None) or {},
            "reference_date": reference_date.isoformat() if reference_date else None,
        }

    async def _create_snapshot(
        self,
        db: AsyncSession,
        task_id: int,
        data: List[Dict[str, Any]],
        started_at: datetime,
    ) -> int:
        """Open a new ``running`` snapshot for this source and make it current.

        The snapshot that was current (latest ``completed_at``, then highest
        id) becomes the parent and is marked non-current. The snapshot's
        ``reference_date`` is the latest parseable ``reference_date`` in
        ``data``. Commits before returning.

        Returns:
            The id of the new snapshot.
        """
        from app.models.data_snapshot import DataSnapshot

        parsed_dates = (self._parse_reference_date(item.get("reference_date")) for item in data)
        reference_dates = [parsed for parsed in parsed_dates if parsed is not None]
        reference_date = max(reference_dates) if reference_dates else None

        result = await db.execute(
            select(DataSnapshot)
            .where(DataSnapshot.source == self.name, DataSnapshot.is_current == True)
            .order_by(DataSnapshot.completed_at.desc().nullslast(), DataSnapshot.id.desc())
            .limit(1)
        )
        previous_snapshot = result.scalar_one_or_none()

        snapshot = DataSnapshot(
            datasource_id=getattr(self, "_datasource_id", 1),
            task_id=task_id,
            source=self.name,
            snapshot_key=f"{self.name}:{task_id}",
            reference_date=reference_date,
            started_at=started_at,
            status="running",
            is_current=True,
            parent_snapshot_id=previous_snapshot.id if previous_snapshot else None,
            summary={},
        )
        db.add(snapshot)
        if previous_snapshot:
            previous_snapshot.is_current = False
        await db.commit()
        return snapshot.id

    async def run(self, db: AsyncSession) -> Dict[str, Any]:
        """Full pipeline: fetch -> transform -> save.

        Creates a ``CollectionTask`` for the run and, after transformation, a
        ``DataSnapshot``. Any exception raised after the task has been created
        is caught. The task (and the snapshot, if one was opened) is then
        marked ``failed`` and a ``failed`` result is returned.

        Returns:
            ``{"status": "skipped", "reason": ...}`` when the collector is
            disabled. Otherwise a dict with ``status``, ``task_id`` and
            ``execution_time_seconds``, plus ``records_processed`` on success
            or ``error`` on failure.
        """
        from app.services.collectors.registry import collector_registry
        from app.models.task import CollectionTask
        from app.models.data_snapshot import DataSnapshot

        start_time = datetime.utcnow()
        datasource_id = getattr(self, "_datasource_id", 1)
        snapshot_id: Optional[int] = None

        if not collector_registry.is_active(self.name):
            return {"status": "skipped", "reason": "Collector is disabled"}

        task = CollectionTask(
            datasource_id=datasource_id,
            status="running",
            phase="queued",
            started_at=start_time,
        )
        db.add(task)
        await db.commit()
        task_id = task.id
        self._current_task = task
        self._db_session = db

        try:
            # Inside the try so a failed URL lookup marks the task as failed
            # instead of leaving it in "running" forever.
            await self.resolve_url(db)

            await self.set_phase("fetching")
            raw_data = await self.fetch()
            task.total_records = len(raw_data)
            await db.commit()

            await self.set_phase("transforming")
            data = self.transform(raw_data)
            snapshot_id = await self._create_snapshot(db, task_id, data, start_time)

            await self.set_phase("saving")
            records_count = await self._save_data(db, data, task_id=task_id, snapshot_id=snapshot_id)

            task.status = "success"
            task.phase = "completed"
            task.records_processed = records_count
            task.progress = 100.0
            task.completed_at = datetime.utcnow()
            await db.commit()
            return {
                "status": "success",
                "task_id": task_id,
                "records_processed": records_count,
                "execution_time_seconds": (datetime.utcnow() - start_time).total_seconds(),
            }
        except Exception as e:
            # Discard whatever the failed step left uncommitted (e.g. a partial
            # batch of records and demoted old rows) and reset the session after
            # a failed flush. Without this, the commit below either persists that
            # partial state or raises PendingRollbackError and the failure is lost.
            await db.rollback()
            task.status = "failed"
            task.phase = "failed"
            task.error_message = str(e)
            task.completed_at = datetime.utcnow()
            if snapshot_id is not None:
                snapshot = await db.get(DataSnapshot, snapshot_id)
                if snapshot:
                    snapshot.status = "failed"
                    snapshot.completed_at = datetime.utcnow()
                    snapshot.summary = {"error": str(e)}
            await db.commit()
            return {
                "status": "failed",
                "task_id": task_id,
                "error": str(e),
                "execution_time_seconds": (datetime.utcnow() - start_time).total_seconds(),
            }

    def _build_extra_data(self, item: Dict[str, Any]) -> Dict[str, Any]:
        """Build the ``extra_data`` payload for one item, with the country normalized."""
        raw_country = item.get("country")
        extra_data = build_dynamic_metadata(
            item.get("metadata", {}),
            country=raw_country,
            city=item.get("city"),
            latitude=item.get("latitude"),
            longitude=item.get("longitude"),
            value=item.get("value"),
            unit=item.get("unit"),
        )
        normalized_country = normalize_country(raw_country)
        if normalized_country is not None:
            extra_data["country"] = normalized_country
        if raw_country and normalized_country != raw_country:
            # Keep the source's spelling whenever normalization changed or rejected it.
            extra_data["raw_country"] = raw_country
            # Only flag as invalid when a country was supplied but could not be
            # normalized; items without any country are not flagged.
            if normalized_country is None:
                extra_data["country_validation"] = "invalid"
        return extra_data

    async def _complete_snapshot(
        self,
        db: AsyncSession,
        snapshot_id: int,
        record_count: int,
        summary: Dict[str, Any],
    ) -> None:
        """Mark snapshot ``snapshot_id`` as successful with the given counts (caller commits)."""
        from app.models.data_snapshot import DataSnapshot

        snapshot = await db.get(DataSnapshot, snapshot_id)
        if snapshot:
            snapshot.record_count = record_count
            snapshot.summary = summary
            snapshot.status = "success"
            snapshot.completed_at = datetime.utcnow()

    async def _save_data(
        self,
        db: AsyncSession,
        data: List[Dict[str, Any]],
        task_id: Optional[int] = None,
        snapshot_id: Optional[int] = None,
    ) -> int:
        """Save transformed data to database.

        Every item is inserted as a new ``CollectedData`` row with
        ``is_current=True``. Existing current rows with the same
        ``entity_key`` are demoted. The new row is classified as
        ``created`` / ``updated`` / ``unchanged`` against the most recent of
        them (by ``collected_at``, then id).

        Args:
            db: Async session used for all queries and commits.
            data: Transformed items (output of ``transform``).
            task_id: Owning ``CollectionTask`` id, if any.
            snapshot_id: Owning ``DataSnapshot`` id. When given, current rows
                of this source from other snapshots are demoted, and the
                snapshot is marked successful with a change summary.

        Returns:
            Number of rows added.
        """
        import logging

        from app.models.collected_data import CollectedData

        logger = logging.getLogger(__name__)

        if not data:
            if snapshot_id is not None:
                await self._complete_snapshot(
                    db,
                    snapshot_id,
                    record_count=0,
                    summary={"created": 0, "updated": 0, "unchanged": 0},
                )
                await db.commit()
            return 0

        collected_at = datetime.utcnow()
        records_added = 0
        created_count = 0
        updated_count = 0
        unchanged_count = 0
        seen_entity_keys: set[str] = set()

        # Entity keys that were current before this save; used to count deletions.
        previous_current_result = await db.execute(
            select(CollectedData.entity_key).where(
                CollectedData.source == self.name,
                CollectedData.is_current == True,
            )
        )
        previous_current_keys = {row[0] for row in previous_current_result.fetchall() if row[0]}

        for i, item in enumerate(data):
            logger.debug(
                "Saving item %d: name=%s, metadata=%s",
                i,
                item.get("name"),
                item.get("metadata", "NOT FOUND"),
            )
            extra_data = self._build_extra_data(item)
            source_id = item.get("source_id") or item.get("id")
            reference_date = self._parse_reference_date(item.get("reference_date"))
            source_id_str = str(source_id) if source_id is not None else None
            # Items without an id fall back to their position in this batch.
            entity_key = f"{self.name}:{source_id_str}" if source_id_str else f"{self.name}:{i}"

            previous_record = None
            # NOTE(review): a repeated entity_key within one batch skips this
            # lookup, so it is recorded as "created" and both rows stay current.
            # Confirm this is intended.
            if entity_key not in seen_entity_keys:
                result = await db.execute(
                    select(CollectedData)
                    .where(
                        CollectedData.source == self.name,
                        CollectedData.entity_key == entity_key,
                        CollectedData.is_current == True,
                    )
                    .order_by(CollectedData.collected_at.desc().nullslast(), CollectedData.id.desc())
                )
                previous_records = result.scalars().all()
                if previous_records:
                    previous_record = previous_records[0]
                    for old_record in previous_records:
                        old_record.is_current = False

            record = CollectedData(
                snapshot_id=snapshot_id,
                task_id=task_id,
                source=self.name,
                source_id=source_id_str,
                entity_key=entity_key,
                data_type=self.data_type,
                name=item.get("name"),
                title=item.get("title"),
                description=item.get("description"),
                extra_data=extra_data,
                collected_at=collected_at,
                reference_date=reference_date,
                is_valid=1,
                is_current=True,
                previous_record_id=previous_record.id if previous_record else None,
                deleted_at=None,
            )

            if previous_record is None:
                record.change_type = "created"
                record.change_summary = {}
                created_count += 1
            else:
                previous_payload = self._build_comparable_payload(previous_record)
                current_payload = self._build_comparable_payload(record)
                if current_payload == previous_payload:
                    record.change_type = "unchanged"
                    record.change_summary = {}
                    unchanged_count += 1
                else:
                    changed_fields = [
                        key for key in current_payload.keys() if current_payload[key] != previous_payload.get(key)
                    ]
                    record.change_type = "updated"
                    record.change_summary = {"changed_fields": changed_fields}
                    updated_count += 1

            db.add(record)
            seen_entity_keys.add(entity_key)
            records_added += 1
            if i % 100 == 0:
                self.update_progress(i + 1)

        # Persist all new rows and demotions together.
        await db.commit()

        if snapshot_id is not None:
            deleted_keys = previous_current_keys - seen_entity_keys
            # Demote every still-current row of this source that belongs to a
            # different snapshot, including entities absent from this batch.
            await db.execute(
                text(
                    """
                    UPDATE collected_data
                    SET is_current = FALSE
                    WHERE source = :source
                      AND snapshot_id IS DISTINCT FROM :snapshot_id
                      AND COALESCE(is_current, TRUE) = TRUE
                    """
                ),
                {"source": self.name, "snapshot_id": snapshot_id},
            )
            await self._complete_snapshot(
                db,
                snapshot_id,
                record_count=records_added,
                summary={
                    "created": created_count,
                    "updated": updated_count,
                    "unchanged": unchanged_count,
                    "deleted": len(deleted_keys),
                },
            )
            await db.commit()

        self.update_progress(len(data))
        return records_added

    async def save(self, db: AsyncSession, data: List[Dict[str, Any]]) -> int:
        """Save data to database (legacy method, use _save_data instead)"""
        return await self._save_data(db, data)
class HTTPCollector(BaseCollector):
    """Collector that retrieves one JSON document with an HTTP GET.

    Subclasses set ``base_url`` (and optionally ``headers``) and implement
    ``parse_response`` to turn the decoded JSON body into records.
    """

    base_url: str = ""
    headers: Dict[str, str] = {}

    async def fetch(self) -> List[Dict[str, Any]]:
        """GET ``base_url`` (60 s timeout) and pass the decoded JSON to ``parse_response``.

        Raises:
            httpx.HTTPStatusError: when the response has a 4xx/5xx status.
        """
        # NOTE(review): this requests ``self.base_url`` and never reads the
        # ``_resolved_url`` that BaseCollector.run() fills in via resolve_url().
        # Confirm which URL is meant to win.
        async with httpx.AsyncClient(timeout=60.0) as client:
            resp = await client.get(self.base_url, headers=self.headers)
            resp.raise_for_status()
            payload = resp.json()
            return self.parse_response(payload)

    @abstractmethod
    def parse_response(self, response: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Convert the decoded JSON body into a list of record dicts."""
        pass
class IntervalCollector(BaseCollector):
    """Base class for collectors that run on intervals.

    Adds no behaviour of its own: ``run`` defers entirely to ``BaseCollector.run``.
    """

    async def run(self, db: AsyncSession) -> Dict[str, Any]:
        """Execute the inherited fetch -> transform -> save pipeline and return its result."""
        outcome = await super().run(db)
        return outcome
async def log_task(
    db: AsyncSession,
    datasource_id: int,
    status: str,
    records_processed: int = 0,
    error_message: Optional[str] = None,
):
    """Log collection task to database.

    Inserts an already-finished ``CollectionTask`` row and commits it.

    Args:
        db: Async session to write with.
        datasource_id: Datasource the task belongs to.
        status: Status value stored on the task.
        records_processed: Number of records to record on the task.
        error_message: Optional error text to store.
    """
    from app.models.task import CollectionTask

    # A single clock read, so this instantaneous entry does not show a
    # spurious gap between started_at and completed_at.
    logged_at = datetime.utcnow()
    task = CollectionTask(
        datasource_id=datasource_id,
        status=status,
        records_processed=records_processed,
        error_message=error_message,
        started_at=logged_at,
        completed_at=logged_at,
    )
    db.add(task)
    await db.commit()