180 lines
5.8 KiB
Python
180 lines
5.8 KiB
Python
"""Base collector class for all data sources"""
|
|
|
|
from abc import ABC, abstractmethod
|
|
from typing import Dict, List, Any, Optional
|
|
from datetime import datetime
|
|
import httpx
|
|
from sqlalchemy import text
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.core.config import settings
|
|
|
|
|
|
class BaseCollector(ABC):
    """Abstract base class for data collectors.

    Subclasses must implement :meth:`fetch` and may override :meth:`transform`.
    :meth:`run` drives fetch -> transform -> save and records the outcome as a
    ``CollectionTask`` row. All timestamps written here are naive UTC values
    from ``datetime.utcnow()``.

    Class attributes (override per subclass):
        name: key checked via ``collector_registry.is_active`` and stored as
            ``CollectedData.source`` on every saved record.
        priority, module, frequency_hours: not read by this class —
            NOTE(review): presumably consumed by the scheduler/registry; confirm.
        data_type: stored as ``CollectedData.data_type`` on every saved record.
    """

    name: str = "base_collector"
    priority: str = "P1"
    module: str = "L1"
    frequency_hours: int = 4
    data_type: str = "generic"  # Override in subclass: "supercomputer", "model", "dataset", etc.

    @abstractmethod
    async def fetch(self) -> List[Dict[str, Any]]:
        """Fetch raw data from source"""

    def transform(self, raw_data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Transform raw data to internal format (default: pass through)"""
        return raw_data

    async def run(self, db: AsyncSession) -> Dict[str, Any]:
        """Run the full pipeline: fetch -> transform -> save.

        A ``CollectionTask`` row is committed with status ``"running"`` before
        fetching, then updated to ``"success"`` or ``"failed"``. Any
        ``Exception`` raised by fetch/transform/save is caught and reported in
        the returned dict instead of propagating.

        Args:
            db: Async session used for both the task row and collected records.

        Returns:
            ``{"status": "skipped", "reason": ...}`` when the collector is
            disabled in the registry; otherwise a dict with ``status``
            (``"success"`` or ``"failed"``), ``task_id``,
            ``execution_time_seconds`` and either ``records_processed``
            (success) or ``error`` (failure).
        """
        from app.services.collectors.registry import collector_registry
        from app.models.task import CollectionTask
        # NOTE(review): not referenced in this method; possibly kept so the
        # model is imported before mapper configuration — confirm before removing.
        from app.models.collected_data import CollectedData  # noqa: F401

        start_time = datetime.utcnow()
        datasource_id = getattr(self, "_datasource_id", 1)  # Default to 1 for built-in collectors

        # Check if collector is active
        if not collector_registry.is_active(self.name):
            return {"status": "skipped", "reason": "Collector is disabled"}

        # Log task start (committed immediately so the "running" row is visible)
        task = CollectionTask(
            datasource_id=datasource_id,
            status="running",
            started_at=start_time,
        )
        db.add(task)
        await db.commit()
        task_id = task.id

        try:
            raw_data = await self.fetch()
            data = self.transform(raw_data)

            # Save data to database
            records_count = await self._save_data(db, data)

            # Log task success
            task.status = "success"
            task.records_processed = records_count
            task.completed_at = datetime.utcnow()
            await db.commit()

            return {
                "status": "success",
                "task_id": task_id,
                "records_processed": records_count,
                "execution_time_seconds": (datetime.utcnow() - start_time).total_seconds(),
            }

        except Exception as e:
            # Roll back before recording the failure. Records added by a save
            # that failed part-way would otherwise be committed together with
            # the "failed" status. After a failed flush/commit, the session
            # also rejects further work until it is rolled back.
            await db.rollback()

            # Log task failure
            task.status = "failed"
            task.error_message = str(e)
            task.completed_at = datetime.utcnow()
            await db.commit()

            return {
                "status": "failed",
                "task_id": task_id,
                "error": str(e),
                "execution_time_seconds": (datetime.utcnow() - start_time).total_seconds(),
            }

    @staticmethod
    def _coord_to_str(value: Any) -> Optional[str]:
        """Return ``str(value)``, or ``None`` when *value* is ``None``."""
        return None if value is None else str(value)

    @staticmethod
    def _parse_reference_date(value: Any) -> Optional[datetime]:
        """Convert an item's ``reference_date`` to a ``datetime`` (or ``None``).

        Falsy values give ``None``. ``datetime`` instances pass through
        unchanged. Strings are parsed with ``datetime.fromisoformat`` after
        rewriting ``Z`` to ``+00:00`` (``fromisoformat`` rejects a ``Z``
        suffix before Python 3.11).

        Raises:
            ValueError: if a string value is not a valid ISO-8601 datetime.
        """
        if not value:
            return None
        if isinstance(value, datetime):
            return value
        return datetime.fromisoformat(value.replace("Z", "+00:00"))

    async def _save_data(self, db: AsyncSession, data: List[Dict[str, Any]]) -> int:
        """Persist transformed items as ``CollectedData`` rows and commit.

        Keys read from each item (missing keys become ``None``):

        - ``source_id``, falling back to ``id``
        - ``name``, ``title``, ``description``, ``country``, ``city``
        - ``latitude`` and ``longitude``, stored as strings
        - ``value``, ``unit``
        - ``metadata``, stored as ``extra_data``; defaults to ``{}``
        - ``reference_date``, as an ISO-8601 string or a ``datetime``

        All records are built before any is added to the session, so an item
        that fails to convert leaves nothing pending.

        Returns:
            Number of records added; 0 (with no commit) for empty input.

        Raises:
            ValueError: if a ``reference_date`` string is not valid ISO-8601.
        """
        from app.models.collected_data import CollectedData

        if not data:
            return 0

        # One timestamp for the whole batch so every row shares collected_at.
        collected_at = datetime.utcnow()

        records = [
            CollectedData(
                source=self.name,
                source_id=item.get("source_id") or item.get("id"),
                data_type=self.data_type,
                name=item.get("name"),
                title=item.get("title"),
                description=item.get("description"),
                country=item.get("country"),
                city=item.get("city"),
                latitude=self._coord_to_str(item.get("latitude")),
                longitude=self._coord_to_str(item.get("longitude")),
                value=item.get("value"),
                unit=item.get("unit"),
                extra_data=item.get("metadata", {}),
                collected_at=collected_at,
                reference_date=self._parse_reference_date(item.get("reference_date")),
                is_valid=1,
            )
            for item in data
        ]

        db.add_all(records)
        await db.commit()
        return len(records)

    async def save(self, db: AsyncSession, data: List[Dict[str, Any]]) -> int:
        """Save data to database (legacy method, use _save_data instead)"""
        return await self._save_data(db, data)
|
|
|
|
|
|
class HTTPCollector(BaseCollector):
    """Base class for collectors that fetch one JSON document via HTTP GET.

    Subclasses set ``base_url`` and, optionally, ``headers`` and
    ``request_timeout``. They implement :meth:`parse_response` to turn the
    decoded JSON body into the item list that :meth:`fetch` returns.
    """

    base_url: str = ""
    # Class-level default shared by every subclass that doesn't override it;
    # assign a new dict in subclasses rather than mutating this one in place.
    headers: Dict[str, str] = {}
    # Seconds passed to httpx.AsyncClient(timeout=...); override per subclass.
    request_timeout: float = 60.0

    async def fetch(self) -> List[Dict[str, Any]]:
        """GET ``base_url`` and return ``parse_response(<decoded JSON body>)``.

        Raises:
            httpx.HTTPStatusError: on a 4xx/5xx response (via ``raise_for_status``).
            httpx.HTTPError: on transport failures or timeouts.
        """
        async with httpx.AsyncClient(timeout=self.request_timeout) as client:
            response = await client.get(self.base_url, headers=self.headers)
            response.raise_for_status()
            return self.parse_response(response.json())

    @abstractmethod
    def parse_response(self, response: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Convert the decoded JSON body into a list of item dicts.

        NOTE(review): ``response`` is whatever ``response.json()`` returned,
        so it may be a list rather than a dict, depending on the API.
        """
|
|
|
|
|
|
class IntervalCollector(BaseCollector):
    """Base class for collectors that run on intervals.

    Adds no behaviour of its own. ``run`` is inherited unchanged from
    ``BaseCollector``; a previous override only delegated to ``super().run``.
    NOTE(review): the interval presumably comes from the inherited
    ``frequency_hours`` attribute via the scheduler — confirm.
    """
|
|
|
|
|
|
async def log_task(
    db: AsyncSession,
    datasource_id: int,
    status: str,
    records_processed: int = 0,
    error_message: Optional[str] = None,
    started_at: Optional[datetime] = None,
):
    """Insert a completed ``CollectionTask`` row and commit it.

    Args:
        db: Async session used to add and commit the row.
        datasource_id: Data source the task belongs to.
        status: Status string stored on the row.
        records_processed: Number of records processed (default 0).
        error_message: Optional error text to store.
        started_at: When the task actually started. Defaults to the same
            naive-UTC timestamp used for ``completed_at``.
    """
    from app.models.task import CollectionTask

    # Take a single timestamp so started_at and completed_at agree exactly
    # when no start time is supplied (two separate utcnow() calls differ).
    completed_at = datetime.utcnow()

    task = CollectionTask(
        datasource_id=datasource_id,
        status=status,
        records_processed=records_processed,
        error_message=error_message,
        started_at=started_at if started_at is not None else completed_at,
        completed_at=completed_at,
    )
    db.add(task)
    await db.commit()
|