first commit
This commit is contained in:
1
backend/tests/__init__.py
Normal file
1
backend/tests/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Test configuration"""
|
||||
BIN
backend/tests/__pycache__/__init__.cpython-311.pyc
Normal file
BIN
backend/tests/__pycache__/__init__.cpython-311.pyc
Normal file
Binary file not shown.
BIN
backend/tests/__pycache__/conftest.cpython-311-pytest-9.0.2.pyc
Normal file
BIN
backend/tests/__pycache__/conftest.cpython-311-pytest-9.0.2.pyc
Normal file
Binary file not shown.
BIN
backend/tests/__pycache__/test_api.cpython-311-pytest-9.0.2.pyc
Normal file
BIN
backend/tests/__pycache__/test_api.cpython-311-pytest-9.0.2.pyc
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
103
backend/tests/conftest.py
Normal file
103
backend/tests/conftest.py
Normal file
@@ -0,0 +1,103 @@
|
||||
"""Pytest configuration and fixtures"""
|
||||
|
||||
import pytest
|
||||
import asyncio
|
||||
from typing import AsyncGenerator
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def event_loop():
|
||||
"""Create event loop for async tests"""
|
||||
loop = asyncio.get_event_loop_policy().new_event_loop()
|
||||
yield loop
|
||||
loop.close()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_db_session():
|
||||
"""Mock database session"""
|
||||
session = AsyncMock(spec=AsyncSession)
|
||||
session.add = MagicMock()
|
||||
session.commit = AsyncMock()
|
||||
session.execute = AsyncMock()
|
||||
session.refresh = AsyncMock()
|
||||
session.close = AsyncMock()
|
||||
return session
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_top500_response():
|
||||
"""Sample TOP500 API response"""
|
||||
return {
|
||||
"items": [
|
||||
{
|
||||
"rank": 1,
|
||||
"system_name": "Frontier",
|
||||
"country": "USA",
|
||||
"city": "Oak Ridge",
|
||||
"latitude": 35.9322,
|
||||
"longitude": -84.3108,
|
||||
"manufacturer": "HPE",
|
||||
"r_max": 1102000.0,
|
||||
"r_peak": 1685000.0,
|
||||
"power": 21510.0,
|
||||
"cores": 8730112,
|
||||
"interconnect": "Slingshot 11",
|
||||
"os": "CentOS",
|
||||
},
|
||||
{
|
||||
"rank": 2,
|
||||
"system_name": "Fugaku",
|
||||
"country": "Japan",
|
||||
"city": "Kobe",
|
||||
"latitude": 34.6913,
|
||||
"longitude": 135.1830,
|
||||
"manufacturer": "Fujitsu",
|
||||
"r_max": 442010.0,
|
||||
"r_peak": 537212.0,
|
||||
"power": 29899.0,
|
||||
"cores": 7630848,
|
||||
"interconnect": "Tofu interconnect D",
|
||||
"os": "RHEL",
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_huggingface_response():
|
||||
"""Sample Hugging Face API response"""
|
||||
return {
|
||||
"models": [
|
||||
{
|
||||
"id": "bert-base-uncased",
|
||||
"author": "google",
|
||||
"description": "BERT base model",
|
||||
"likes": 25000,
|
||||
"downloads": 5000000,
|
||||
"language": "en",
|
||||
"tags": ["transformer", "bert"],
|
||||
"pipeline_tag": "feature-extraction",
|
||||
"library_name": "transformers",
|
||||
"createdAt": "2024-01-15T10:00:00Z",
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_alert_data():
|
||||
"""Sample alert data"""
|
||||
return {
|
||||
"id": 1,
|
||||
"severity": "warning",
|
||||
"status": "active",
|
||||
"datasource_id": 2,
|
||||
"datasource_name": "Epoch AI",
|
||||
"message": "API response time > 30s",
|
||||
"created_at": "2024-01-20T09:30:00Z",
|
||||
"acknowledged_by": None,
|
||||
}
|
||||
108
backend/tests/test_api.py
Normal file
108
backend/tests/test_api.py
Normal file
@@ -0,0 +1,108 @@
|
||||
"""API endpoint tests"""
|
||||
|
||||
import pytest
|
||||
from datetime import datetime
|
||||
from unittest.mock import patch, AsyncMock
|
||||
from httpx import AsyncClient, ASGITransport
|
||||
|
||||
from app.main import app
|
||||
from app.core.config import settings
|
||||
from app.core.security import create_access_token
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def auth_headers():
|
||||
"""Create authentication headers"""
|
||||
token = create_access_token({"sub": "1", "username": "testuser"})
|
||||
return {"Authorization": f"Bearer {token}"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_health_check():
|
||||
"""Test health check endpoint"""
|
||||
transport = ASGITransport(app=app)
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as client:
|
||||
response = await client.get("/health")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["status"] == "healthy"
|
||||
assert "version" in data
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_root_endpoint():
|
||||
"""Test root endpoint"""
|
||||
transport = ASGITransport(app=app)
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as client:
|
||||
response = await client.get("/")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["name"] == settings.PROJECT_NAME
|
||||
assert data["version"] == settings.VERSION
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dashboard_stats_without_auth():
|
||||
"""Test dashboard stats requires authentication"""
|
||||
transport = ASGITransport(app=app)
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as client:
|
||||
response = await client.get("/api/v1/dashboard/stats")
|
||||
assert response.status_code == 401
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dashboard_stats_with_auth(auth_headers):
|
||||
"""Test dashboard stats with authentication"""
|
||||
with patch("app.api.v1.dashboard.cache.get", return_value=None):
|
||||
with patch("app.api.v1.dashboard.cache.set", return_value=True):
|
||||
with patch("app.db.session.get_db") as mock_get_db:
|
||||
mock_session = AsyncMock()
|
||||
mock_result = AsyncMock()
|
||||
mock_result.scalar.return_value = 0
|
||||
mock_result.fetchall.return_value = []
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
async def mock_db_context():
|
||||
yield mock_session
|
||||
|
||||
mock_get_db.return_value = mock_db_context()
|
||||
|
||||
transport = ASGITransport(app=app)
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as client:
|
||||
response = await client.get(
|
||||
"/api/v1/dashboard/stats",
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "total_datasources" in data
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_alerts_without_auth():
|
||||
"""Test alerts endpoint requires authentication"""
|
||||
transport = ASGITransport(app=app)
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as client:
|
||||
response = await client.get("/api/v1/alerts")
|
||||
assert response.status_code == 401
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_alerts_endpoint_with_auth(auth_headers):
|
||||
"""Test alerts endpoint with authentication"""
|
||||
transport = ASGITransport(app=app)
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as client:
|
||||
response = await client.get("/api/v1/alerts", headers=auth_headers)
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_token():
|
||||
"""Test that invalid token is rejected"""
|
||||
transport = ASGITransport(app=app)
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as client:
|
||||
response = await client.get(
|
||||
"/api/v1/dashboard/stats",
|
||||
headers={"Authorization": "Bearer invalid_token"},
|
||||
)
|
||||
assert response.status_code == 401
|
||||
112
backend/tests/test_collectors.py
Normal file
112
backend/tests/test_collectors.py
Normal file
@@ -0,0 +1,112 @@
|
||||
"""Unit tests for data collectors"""
|
||||
|
||||
import pytest
|
||||
from datetime import datetime
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from app.services.collectors.top500 import TOP500Collector
|
||||
from app.services.collectors.base import BaseCollector, HTTPCollector
|
||||
|
||||
|
||||
class TestBaseCollector:
|
||||
"""Tests for BaseCollector"""
|
||||
|
||||
def test_base_collector_attributes(self):
|
||||
"""Test base collector has correct default attributes via concrete class"""
|
||||
collector = TOP500Collector()
|
||||
assert collector.name == "top500"
|
||||
assert collector.priority == "P0"
|
||||
assert collector.module == "L1"
|
||||
assert collector.frequency_hours == 4
|
||||
|
||||
|
||||
class TestTOP500Collector:
|
||||
"""Tests for TOP500Collector"""
|
||||
|
||||
def test_parse_coordinate_valid_float(self):
|
||||
"""Test parsing valid float coordinate"""
|
||||
collector = TOP500Collector()
|
||||
assert collector._parse_coordinate(45.5) == 45.5
|
||||
|
||||
def test_parse_coordinate_valid_string(self):
|
||||
"""Test parsing valid string coordinate"""
|
||||
collector = TOP500Collector()
|
||||
assert collector._parse_coordinate("45.5") == 45.5
|
||||
|
||||
def test_parse_coordinate_invalid_string(self):
|
||||
"""Test parsing invalid string coordinate"""
|
||||
collector = TOP500Collector()
|
||||
assert collector._parse_coordinate("invalid") == 0.0
|
||||
|
||||
def test_parse_coordinate_none(self):
|
||||
"""Test parsing None coordinate"""
|
||||
collector = TOP500Collector()
|
||||
assert collector._parse_coordinate(None) == 0.0
|
||||
|
||||
def test_parse_response_empty(self):
|
||||
"""Test parsing empty response"""
|
||||
collector = TOP500Collector()
|
||||
result = collector.parse_response({"items": []})
|
||||
assert result == []
|
||||
|
||||
def test_parse_response_single_item(self):
|
||||
"""Test parsing single item response"""
|
||||
collector = TOP500Collector()
|
||||
response = {
|
||||
"items": [
|
||||
{
|
||||
"rank": 1,
|
||||
"system_name": "Test Supercomputer",
|
||||
"country": "USA",
|
||||
"city": "San Francisco",
|
||||
"latitude": 37.7749,
|
||||
"longitude": -122.4194,
|
||||
"manufacturer": "Test Corp",
|
||||
"r_max": 100000.0,
|
||||
"r_peak": 150000.0,
|
||||
"power": 5000.0,
|
||||
"cores": 100000,
|
||||
"interconnect": "InfiniBand",
|
||||
"os": "Linux",
|
||||
}
|
||||
]
|
||||
}
|
||||
result = collector.parse_response(response)
|
||||
assert len(result) == 1
|
||||
assert result[0]["cluster_id"] == "top500_1"
|
||||
assert result[0]["name"] == "Test Supercomputer"
|
||||
assert result[0]["country"] == "USA"
|
||||
assert result[0]["rank"] == 1
|
||||
assert result[0]["source"] == "TOP500"
|
||||
|
||||
def test_parse_response_skips_invalid_item(self):
|
||||
"""Test parsing skips items with missing data"""
|
||||
collector = TOP500Collector()
|
||||
response = {
|
||||
"items": [
|
||||
{"rank": 1, "system_name": "Valid"},
|
||||
{"rank": None, "system_name": "Invalid"},
|
||||
]
|
||||
}
|
||||
result = collector.parse_response(response)
|
||||
assert len(result) == 1
|
||||
assert result[0]["name"] == "Valid"
|
||||
|
||||
|
||||
class TestHTTPCollector:
|
||||
"""Tests for HTTPCollector"""
|
||||
|
||||
def test_http_collector_attributes(self):
|
||||
"""Test HTTP collector has correct default attributes via concrete class"""
|
||||
collector = TOP500Collector()
|
||||
assert collector.base_url == "https://top500.org/api/v1.0/lists/"
|
||||
assert collector.name == "top500"
|
||||
assert collector.priority == "P0"
|
||||
|
||||
def test_collector_has_required_methods(self):
|
||||
"""Test HTTP collector has required methods"""
|
||||
collector = TOP500Collector()
|
||||
assert hasattr(collector, "fetch")
|
||||
assert hasattr(collector, "parse_response")
|
||||
assert callable(collector.fetch)
|
||||
assert callable(collector.parse_response)
|
||||
131
backend/tests/test_models.py
Normal file
131
backend/tests/test_models.py
Normal file
@@ -0,0 +1,131 @@
|
||||
"""Unit tests for models"""
|
||||
|
||||
import pytest
|
||||
from datetime import datetime
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from app.models.user import User
|
||||
from app.models.alert import Alert, AlertSeverity, AlertStatus
|
||||
from app.models.task import CollectionTask
|
||||
|
||||
|
||||
class TestUserModel:
|
||||
"""Tests for User model"""
|
||||
|
||||
def test_user_creation(self):
|
||||
"""Test user model creation"""
|
||||
user = User(
|
||||
id=1,
|
||||
username="testuser",
|
||||
email="test@example.com",
|
||||
password_hash="hashed_password",
|
||||
role="admin",
|
||||
is_active=True,
|
||||
)
|
||||
assert user.id == 1
|
||||
assert user.username == "testuser"
|
||||
assert user.email == "test@example.com"
|
||||
assert user.is_active is True
|
||||
|
||||
def test_user_role_assignment(self):
|
||||
"""Test user role assignment"""
|
||||
user = User(
|
||||
username="testuser",
|
||||
email="test@example.com",
|
||||
password_hash="hashed",
|
||||
role="admin",
|
||||
)
|
||||
assert user.role == "admin"
|
||||
|
||||
def test_user_password_hash(self):
|
||||
"""Test user password hash attribute"""
|
||||
user = User(
|
||||
username="testuser",
|
||||
email="test@example.com",
|
||||
password_hash="hashed_password",
|
||||
)
|
||||
assert user.password_hash == "hashed_password"
|
||||
|
||||
|
||||
class TestAlertModel:
|
||||
"""Tests for Alert model"""
|
||||
|
||||
def test_alert_creation(self):
|
||||
"""Test alert model creation"""
|
||||
alert = Alert(
|
||||
id=1,
|
||||
severity=AlertSeverity.WARNING,
|
||||
status=AlertStatus.ACTIVE,
|
||||
message="Test alert message",
|
||||
datasource_id=1,
|
||||
datasource_name="Test Source",
|
||||
)
|
||||
assert alert.id == 1
|
||||
assert alert.severity == AlertSeverity.WARNING
|
||||
assert alert.status == AlertStatus.ACTIVE
|
||||
assert alert.message == "Test alert message"
|
||||
|
||||
def test_alert_to_dict(self):
|
||||
"""Test alert to_dict method"""
|
||||
alert = Alert(
|
||||
id=1,
|
||||
severity=AlertSeverity.CRITICAL,
|
||||
status=AlertStatus.ACTIVE,
|
||||
message="Critical alert",
|
||||
datasource_id=2,
|
||||
datasource_name="Test Source",
|
||||
created_at=datetime(2024, 1, 1, 12, 0, 0),
|
||||
)
|
||||
result = alert.to_dict()
|
||||
assert result["id"] == 1
|
||||
assert result["severity"] == "critical"
|
||||
assert result["status"] == "active"
|
||||
assert result["message"] == "Critical alert"
|
||||
assert result["created_at"] == "2024-01-01T12:00:00"
|
||||
|
||||
def test_alert_severity_enum(self):
|
||||
"""Test alert severity enum values"""
|
||||
assert AlertSeverity.CRITICAL.value == "critical"
|
||||
assert AlertSeverity.WARNING.value == "warning"
|
||||
assert AlertSeverity.INFO.value == "info"
|
||||
|
||||
def test_alert_status_enum(self):
|
||||
"""Test alert status enum values"""
|
||||
assert AlertStatus.ACTIVE.value == "active"
|
||||
assert AlertStatus.ACKNOWLEDGED.value == "acknowledged"
|
||||
assert AlertStatus.RESOLVED.value == "resolved"
|
||||
|
||||
|
||||
class TestCollectionTaskModel:
|
||||
"""Tests for CollectionTask model"""
|
||||
|
||||
def test_task_creation(self):
|
||||
"""Test collection task creation"""
|
||||
task = CollectionTask(
|
||||
id=1,
|
||||
datasource_id=1,
|
||||
status="running",
|
||||
records_processed=0,
|
||||
started_at=datetime.utcnow(),
|
||||
)
|
||||
assert task.id == 1
|
||||
assert task.datasource_id == 1
|
||||
assert task.status == "running"
|
||||
|
||||
def test_task_with_records(self):
|
||||
"""Test collection task with records processed"""
|
||||
task = CollectionTask(
|
||||
datasource_id=1,
|
||||
status="success",
|
||||
records_processed=100,
|
||||
)
|
||||
assert task.records_processed == 100
|
||||
|
||||
def test_task_error_message(self):
|
||||
"""Test collection task with error message"""
|
||||
task = CollectionTask(
|
||||
datasource_id=1,
|
||||
status="failed",
|
||||
error_message="Connection timeout",
|
||||
)
|
||||
assert task.error_message == "Connection timeout"
|
||||
113
backend/tests/test_security.py
Normal file
113
backend/tests/test_security.py
Normal file
@@ -0,0 +1,113 @@
|
||||
"""Unit tests for security module"""
|
||||
|
||||
import pytest
|
||||
from datetime import datetime, timedelta
|
||||
from jose import jwt
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from app.core.security import (
|
||||
create_access_token,
|
||||
create_refresh_token,
|
||||
verify_password,
|
||||
get_password_hash,
|
||||
)
|
||||
from app.core.config import settings
|
||||
|
||||
|
||||
class TestPasswordHashing:
|
||||
"""Tests for password hashing functions"""
|
||||
|
||||
def test_hash_password(self):
|
||||
"""Test password hashing"""
|
||||
password = "test_password_123"
|
||||
hashed = get_password_hash(password)
|
||||
assert hashed != password
|
||||
assert len(hashed) > 0
|
||||
|
||||
def test_verify_correct_password(self):
|
||||
"""Test verification of correct password"""
|
||||
password = "test_password_123"
|
||||
hashed = get_password_hash(password)
|
||||
assert verify_password(password, hashed) is True
|
||||
|
||||
def test_verify_incorrect_password(self):
|
||||
"""Test verification of incorrect password"""
|
||||
password = "test_password_123"
|
||||
hashed = get_password_hash(password)
|
||||
assert verify_password("wrong_password", hashed) is False
|
||||
|
||||
def test_hash_is_unique(self):
|
||||
"""Test that hashes are unique for same password"""
|
||||
password = "test_password_123"
|
||||
hash1 = get_password_hash(password)
|
||||
hash2 = get_password_hash(password)
|
||||
assert hash1 != hash2 # bcrypt adds salt
|
||||
|
||||
|
||||
class TestTokenCreation:
|
||||
"""Tests for token creation functions"""
|
||||
|
||||
def test_create_access_token(self):
|
||||
"""Test access token creation"""
|
||||
data = {"sub": "123", "username": "testuser"}
|
||||
token = create_access_token(data)
|
||||
assert token is not None
|
||||
assert len(token) > 0
|
||||
|
||||
payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
|
||||
assert payload["sub"] == "123"
|
||||
assert payload["username"] == "testuser"
|
||||
assert payload["type"] == "access"
|
||||
|
||||
def test_create_refresh_token(self):
|
||||
"""Test refresh token creation"""
|
||||
data = {"sub": "123"}
|
||||
token = create_refresh_token(data)
|
||||
assert token is not None
|
||||
|
||||
payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
|
||||
assert payload["sub"] == "123"
|
||||
assert payload["type"] == "refresh"
|
||||
|
||||
def test_access_token_expiration(self):
|
||||
"""Test access token has correct expiration"""
|
||||
data = {"sub": "123"}
|
||||
token = create_access_token(data)
|
||||
payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
|
||||
exp_timestamp = payload["exp"]
|
||||
# Token should expire in approximately 15 minutes (accounting for timezone)
|
||||
expected_minutes = settings.ACCESS_TOKEN_EXPIRE_MINUTES
|
||||
# The timestamp is in seconds since epoch
|
||||
import time
|
||||
|
||||
now_timestamp = time.time()
|
||||
minutes_diff = (exp_timestamp - now_timestamp) / 60
|
||||
assert expected_minutes - 1 < minutes_diff < expected_minutes + 1
|
||||
|
||||
def test_refresh_token_expiration(self):
|
||||
"""Test refresh token has correct expiration"""
|
||||
data = {"sub": "123"}
|
||||
token = create_refresh_token(data)
|
||||
payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
|
||||
exp = datetime.fromtimestamp(payload["exp"])
|
||||
now = datetime.utcnow()
|
||||
# Token should expire in approximately 7 days (with some tolerance)
|
||||
delta = exp - now
|
||||
assert delta.days >= 6 # At least 6 days
|
||||
assert delta.days <= 8 # Less than 8 days
|
||||
|
||||
|
||||
class TestJWTSecurity:
|
||||
"""Tests for JWT security features"""
|
||||
|
||||
def test_invalid_token_raises_error(self):
|
||||
"""Test that invalid token raises JWTError"""
|
||||
with pytest.raises(jwt.JWTError):
|
||||
jwt.decode("invalid_token", settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
|
||||
|
||||
def test_token_with_wrong_secret_raises_error(self):
|
||||
"""Test that token with wrong secret raises error"""
|
||||
data = {"sub": "123"}
|
||||
token = create_access_token(data)
|
||||
with pytest.raises(jwt.JWTError):
|
||||
jwt.decode(token, "wrong_secret", algorithms=[settings.ALGORITHM])
|
||||
Reference in New Issue
Block a user