414 lines
15 KiB
Python
414 lines
15 KiB
Python
"""Base collector class for all data sources"""
|
|
|
|
import logging
from abc import ABC, abstractmethod
from datetime import datetime
from typing import Dict, List, Any, Optional

import httpx
from sqlalchemy import select, text
from sqlalchemy.ext.asyncio import AsyncSession

from app.core.collected_data_fields import build_dynamic_metadata, get_record_field
from app.core.config import settings
from app.core.countries import normalize_country
|
|
|
|
|
|
class BaseCollector(ABC):
    """Abstract base class for data collectors.

    Subclasses implement :meth:`fetch` (and usually override
    :meth:`transform`); :meth:`run` drives the full pipeline
    (fetch -> transform -> snapshot -> save) while persisting task
    progress and per-record change history through SQLAlchemy.
    """

    # Collector identity and scheduling defaults; subclasses override these.
    name: str = "base_collector"
    priority: str = "P1"
    module: str = "L1"
    frequency_hours: int = 4
    data_type: str = "generic"

    def __init__(self):
        # Task/session references are populated by run() so helpers such as
        # update_progress()/set_phase() can report status mid-collection.
        self._current_task = None
        self._db_session = None
        self._datasource_id = 1  # default datasource id; may be overwritten externally
        self._resolved_url: Optional[str] = None

    async def resolve_url(self, db: AsyncSession) -> None:
        """Resolve and cache this collector's source URL from configuration."""
        from app.core.data_sources import get_data_sources_config

        config = get_data_sources_config()
        self._resolved_url = await config.get_url(self.name, db)

    def update_progress(self, records_processed: int):
        """Update task progress - call this during data processing.

        Only mutates the in-memory task; the caller commits when convenient.
        """
        if self._current_task and self._db_session and self._current_task.total_records > 0:
            self._current_task.records_processed = records_processed
            self._current_task.progress = (
                records_processed / self._current_task.total_records
            ) * 100

    async def set_phase(self, phase: str):
        """Record the current pipeline phase and persist it immediately."""
        if self._current_task and self._db_session:
            self._current_task.phase = phase
            await self._db_session.commit()

    @abstractmethod
    async def fetch(self) -> List[Dict[str, Any]]:
        """Fetch raw data from source"""
        pass

    def transform(self, raw_data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Transform raw data to internal format (default: pass through)"""
        return raw_data

    def _parse_reference_date(self, value: Any) -> Optional[datetime]:
        """Best-effort parse of a reference date.

        Accepts a datetime (returned unchanged) or an ISO-8601 string
        (a trailing "Z" is normalized to "+00:00").  Returns None for
        falsy values, unsupported types, and unparseable strings.
        """
        if not value:
            return None
        if isinstance(value, datetime):
            return value
        if isinstance(value, str):
            try:
                return datetime.fromisoformat(value.replace("Z", "+00:00"))
            except ValueError:
                # Fix: a malformed date string used to raise and abort the
                # whole collection run; treat it as "no reference date".
                return None
        return None

    def _build_comparable_payload(self, record: Any) -> Dict[str, Any]:
        """Project a record onto the fields used for change detection."""
        return {
            "name": getattr(record, "name", None),
            "title": getattr(record, "title", None),
            "description": getattr(record, "description", None),
            "country": get_record_field(record, "country"),
            "city": get_record_field(record, "city"),
            "latitude": get_record_field(record, "latitude"),
            "longitude": get_record_field(record, "longitude"),
            "value": get_record_field(record, "value"),
            "unit": get_record_field(record, "unit"),
            "metadata": getattr(record, "extra_data", None) or {},
            # Compare dates as ISO strings so DB rows and fresh objects match.
            "reference_date": (
                getattr(record, "reference_date", None).isoformat()
                if getattr(record, "reference_date", None)
                else None
            ),
        }

    async def _create_snapshot(
        self,
        db: AsyncSession,
        task_id: int,
        data: List[Dict[str, Any]],
        started_at: datetime,
    ) -> int:
        """Create a new "current" DataSnapshot row and return its id.

        The previous current snapshot (if any) is demoted and linked via
        parent_snapshot_id.  The snapshot's reference_date is the latest
        parseable reference_date found in *data*.
        """
        from app.models.data_snapshot import DataSnapshot

        reference_dates = [
            parsed
            for parsed in (self._parse_reference_date(item.get("reference_date")) for item in data)
            if parsed is not None
        ]
        reference_date = max(reference_dates) if reference_dates else None

        result = await db.execute(
            select(DataSnapshot)
            .where(DataSnapshot.source == self.name, DataSnapshot.is_current == True)
            .order_by(DataSnapshot.completed_at.desc().nullslast(), DataSnapshot.id.desc())
            .limit(1)
        )
        previous_snapshot = result.scalar_one_or_none()

        snapshot = DataSnapshot(
            datasource_id=getattr(self, "_datasource_id", 1),
            task_id=task_id,
            source=self.name,
            snapshot_key=f"{self.name}:{task_id}",
            reference_date=reference_date,
            started_at=started_at,
            status="running",
            is_current=True,
            parent_snapshot_id=previous_snapshot.id if previous_snapshot else None,
            summary={},
        )
        db.add(snapshot)

        if previous_snapshot:
            previous_snapshot.is_current = False

        await db.commit()
        return snapshot.id

    async def run(self, db: AsyncSession) -> Dict[str, Any]:
        """Full pipeline: fetch -> transform -> save.

        Returns a status dict: "skipped" when the collector is disabled,
        otherwise "success"/"failed" with the task id and timing info.
        All exceptions are caught, recorded on the task (and snapshot, if
        created), and reported in the return value rather than re-raised.
        """
        from app.services.collectors.registry import collector_registry
        from app.models.task import CollectionTask
        from app.models.data_snapshot import DataSnapshot

        start_time = datetime.utcnow()
        datasource_id = getattr(self, "_datasource_id", 1)
        snapshot_id: Optional[int] = None

        if not collector_registry.is_active(self.name):
            return {"status": "skipped", "reason": "Collector is disabled"}

        task = CollectionTask(
            datasource_id=datasource_id,
            status="running",
            phase="queued",
            started_at=start_time,
        )
        db.add(task)
        await db.commit()
        task_id = task.id

        # Expose task/session so update_progress()/set_phase() can report.
        self._current_task = task
        self._db_session = db

        await self.resolve_url(db)

        try:
            await self.set_phase("fetching")
            raw_data = await self.fetch()
            task.total_records = len(raw_data)
            await db.commit()

            await self.set_phase("transforming")
            data = self.transform(raw_data)
            snapshot_id = await self._create_snapshot(db, task_id, data, start_time)

            await self.set_phase("saving")
            records_count = await self._save_data(db, data, task_id=task_id, snapshot_id=snapshot_id)

            task.status = "success"
            task.phase = "completed"
            task.records_processed = records_count
            task.progress = 100.0
            task.completed_at = datetime.utcnow()
            await db.commit()

            return {
                "status": "success",
                "task_id": task_id,
                "records_processed": records_count,
                "execution_time_seconds": (datetime.utcnow() - start_time).total_seconds(),
            }
        except Exception as e:
            task.status = "failed"
            task.phase = "failed"
            task.error_message = str(e)
            task.completed_at = datetime.utcnow()
            # Mark the snapshot failed too, if it was already created.
            if snapshot_id is not None:
                snapshot = await db.get(DataSnapshot, snapshot_id)
                if snapshot:
                    snapshot.status = "failed"
                    snapshot.completed_at = datetime.utcnow()
                    snapshot.summary = {"error": str(e)}
            await db.commit()

            return {
                "status": "failed",
                "task_id": task_id,
                "error": str(e),
                "execution_time_seconds": (datetime.utcnow() - start_time).total_seconds(),
            }

    async def _save_data(
        self,
        db: AsyncSession,
        data: List[Dict[str, Any]],
        task_id: Optional[int] = None,
        snapshot_id: Optional[int] = None,
    ) -> int:
        """Save transformed data to database.

        Inserts one CollectedData row per item, classifies each row as
        created/updated/unchanged against the previous "current" row with
        the same entity_key, demotes superseded rows, and finalizes the
        snapshot summary (including a "deleted" count for keys that were
        current before but not re-seen).  Returns the number of rows added.
        """
        from app.models.collected_data import CollectedData
        from app.models.data_snapshot import DataSnapshot

        if not data:
            # Nothing to save: finalize the snapshot as an empty success.
            if snapshot_id is not None:
                snapshot = await db.get(DataSnapshot, snapshot_id)
                if snapshot:
                    snapshot.record_count = 0
                    snapshot.summary = {"created": 0, "updated": 0, "unchanged": 0}
                    snapshot.status = "success"
                    snapshot.completed_at = datetime.utcnow()
                await db.commit()
            return 0

        logger = logging.getLogger(__name__)
        collected_at = datetime.utcnow()
        records_added = 0
        created_count = 0
        updated_count = 0
        unchanged_count = 0
        seen_entity_keys: set[str] = set()
        previous_current_keys: set[str] = set()

        # Keys that were current before this run; used to derive deletions.
        previous_current_result = await db.execute(
            select(CollectedData.entity_key).where(
                CollectedData.source == self.name,
                CollectedData.is_current == True,
            )
        )
        previous_current_keys = {row[0] for row in previous_current_result.fetchall() if row[0]}

        for i, item in enumerate(data):
            # Fix: was an unconditional print() on every record; use lazy
            # debug logging instead.
            logger.debug(
                "Saving item %s: name=%s, metadata=%s",
                i,
                item.get("name"),
                item.get("metadata", "NOT FOUND"),
            )
            raw_metadata = item.get("metadata", {})
            extra_data = build_dynamic_metadata(
                raw_metadata,
                country=item.get("country"),
                city=item.get("city"),
                latitude=item.get("latitude"),
                longitude=item.get("longitude"),
                value=item.get("value"),
                unit=item.get("unit"),
            )
            normalized_country = normalize_country(item.get("country"))
            if normalized_country is not None:
                extra_data["country"] = normalized_country

            # Keep the original value when normalization changed (or failed on) it.
            if item.get("country") and normalized_country != item.get("country"):
                extra_data["raw_country"] = item.get("country")
            if normalized_country is None:
                extra_data["country_validation"] = "invalid"

            source_id = item.get("source_id") or item.get("id")
            reference_date = self._parse_reference_date(item.get("reference_date"))
            source_id_str = str(source_id) if source_id is not None else None
            # Stable identity: prefer the source's own id, else the position.
            entity_key = f"{self.name}:{source_id_str}" if source_id_str else f"{self.name}:{i}"
            previous_record = None

            if entity_key and entity_key not in seen_entity_keys:
                result = await db.execute(
                    select(CollectedData)
                    .where(
                        CollectedData.source == self.name,
                        CollectedData.entity_key == entity_key,
                        CollectedData.is_current == True,
                    )
                    .order_by(CollectedData.collected_at.desc().nullslast(), CollectedData.id.desc())
                )
                previous_records = result.scalars().all()
                if previous_records:
                    # Newest row is the comparison baseline; all are demoted.
                    previous_record = previous_records[0]
                    for old_record in previous_records:
                        old_record.is_current = False

            record = CollectedData(
                snapshot_id=snapshot_id,
                task_id=task_id,
                source=self.name,
                source_id=source_id_str,
                entity_key=entity_key,
                data_type=self.data_type,
                name=item.get("name"),
                title=item.get("title"),
                description=item.get("description"),
                extra_data=extra_data,
                collected_at=collected_at,
                reference_date=reference_date,
                is_valid=1,
                is_current=True,
                previous_record_id=previous_record.id if previous_record else None,
                deleted_at=None,
            )

            if previous_record is None:
                record.change_type = "created"
                record.change_summary = {}
                created_count += 1
            else:
                previous_payload = self._build_comparable_payload(previous_record)
                current_payload = self._build_comparable_payload(record)
                if current_payload == previous_payload:
                    record.change_type = "unchanged"
                    record.change_summary = {}
                    unchanged_count += 1
                else:
                    changed_fields = [
                        key for key in current_payload.keys() if current_payload[key] != previous_payload.get(key)
                    ]
                    record.change_type = "updated"
                    record.change_summary = {"changed_fields": changed_fields}
                    updated_count += 1

            db.add(record)
            seen_entity_keys.add(entity_key)
            records_added += 1

            # Periodic progress report + commit keeps transactions bounded.
            if i % 100 == 0:
                self.update_progress(i + 1)
                await db.commit()

        if snapshot_id is not None:
            # Anything current before that was not re-seen counts as deleted.
            deleted_keys = previous_current_keys - seen_entity_keys
            await db.execute(
                text(
                    """
                    UPDATE collected_data
                    SET is_current = FALSE
                    WHERE source = :source
                    AND snapshot_id IS DISTINCT FROM :snapshot_id
                    AND COALESCE(is_current, TRUE) = TRUE
                    """
                ),
                {"source": self.name, "snapshot_id": snapshot_id},
            )
            snapshot = await db.get(DataSnapshot, snapshot_id)
            if snapshot:
                snapshot.record_count = records_added
                snapshot.status = "success"
                snapshot.completed_at = datetime.utcnow()
                snapshot.summary = {
                    "created": created_count,
                    "updated": updated_count,
                    "unchanged": unchanged_count,
                    "deleted": len(deleted_keys),
                }

        await db.commit()
        self.update_progress(len(data))
        return records_added

    async def save(self, db: AsyncSession, data: List[Dict[str, Any]]) -> int:
        """Save data to database (legacy method, use _save_data instead)"""
        return await self._save_data(db, data)
|
|
|
|
|
|
class HTTPCollector(BaseCollector):
    """Base class for HTTP API collectors.

    Subclasses set ``base_url`` (and optionally ``headers``/``timeout``)
    and implement :meth:`parse_response` to turn the JSON payload into
    record dicts.
    """

    base_url: str = ""
    headers: Dict[str, str] = {}
    # Request timeout in seconds; was hard-coded to 60.0 inside fetch().
    timeout: float = 60.0

    async def fetch(self) -> List[Dict[str, Any]]:
        """GET ``base_url`` and return the parsed record list.

        Raises:
            httpx.HTTPStatusError: on a non-2xx response (via
                ``raise_for_status``).
        """
        async with httpx.AsyncClient(timeout=self.timeout) as client:
            response = await client.get(self.base_url, headers=self.headers)
            response.raise_for_status()
            return self.parse_response(response.json())

    @abstractmethod
    def parse_response(self, response: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Convert the decoded JSON payload into a list of record dicts."""
        pass
|
|
|
|
|
|
class IntervalCollector(BaseCollector):
    """Base class for collectors that run on intervals"""

    async def run(self, db: AsyncSession) -> Dict[str, Any]:
        # NOTE(review): this override adds nothing over BaseCollector.run;
        # presumably a hook point for future interval-specific behavior —
        # confirm before removing.
        return await super().run(db)
|
|
|
|
|
|
async def log_task(
    db: AsyncSession,
    datasource_id: int,
    status: str,
    records_processed: int = 0,
    error_message: Optional[str] = None,
):
    """Log collection task to database.

    Records an after-the-fact CollectionTask row and commits it.

    Args:
        db: Async database session; committed before returning.
        datasource_id: Id of the datasource the task belongs to.
        status: Final task status (e.g. "success" or "failed").
        records_processed: Number of records the task handled.
        error_message: Optional failure detail.
    """
    from app.models.task import CollectionTask

    # Fix: capture one timestamp so started_at == completed_at exactly
    # (two separate utcnow() calls could differ by microseconds).
    now = datetime.utcnow()
    task = CollectionTask(
        datasource_id=datasource_id,
        status=status,
        records_processed=records_processed,
        error_message=error_message,
        started_at=now,
        completed_at=now,
    )
    db.add(task)
    await db.commit()
|