246 lines
8.0 KiB
Python
246 lines
8.0 KiB
Python
"""Tests for authentication layer.
|
|
|
|
This module tests JWT validation and authentication dependencies.
|
|
"""
|
|
|
|
import asyncio
import time
from datetime import datetime, timedelta, timezone
from typing import Optional

import pytest
from fastapi import HTTPException
from fastapi.testclient import TestClient
from jose import jwt

from app.api.deps import get_current_user, validate_supabase_token
from app.core.config import get_settings
from app.main import app
from app.schemas.token import TokenPayload, UserContext
|
|
|
|
# Application settings, obtained once when this module is imported. The tests
# below read SUPABASE_JWT_SECRET (falling back to SECRET_KEY) from it to sign
# the JWTs they send to the code under test.
settings = get_settings()
|
|
|
|
|
|
class TestJWTValidation:
    """Test cases for JWT validation performed by ``validate_supabase_token``."""

    def create_test_token(
        self,
        secret: Optional[str] = None,
        expired: bool = False,
        wrong_audience: bool = False,
        missing_claims: bool = False,
        algorithm: str = "HS256",
    ) -> str:
        """Create a test JWT with the requested properties.

        Args:
            secret: JWT signing secret. ``None`` (the default) signs with
                ``settings.SUPABASE_JWT_SECRET``, or ``settings.SECRET_KEY``
                when that is empty. Any other value, including ``""``, is used
                as-is so callers can deliberately produce a bad signature.
            expired: If True, ``exp`` is set one hour in the past; otherwise
                one hour in the future. Ignored when ``missing_claims`` is True.
            wrong_audience: If True, ``aud`` is ``"wrong-audience"`` instead
                of ``"authenticated"``.
            missing_claims: If True, the ``exp`` claim is omitted entirely;
                every other claim is still included.
            algorithm: Signing algorithm passed to ``jwt.encode``.

        Returns:
            Encoded JWT string.
        """
        # Explicit None check: a caller-supplied empty secret must not be
        # silently replaced by the real signing key.
        if secret is None:
            secret = settings.SUPABASE_JWT_SECRET or settings.SECRET_KEY

        payload = {
            "sub": "550e8400-e29b-41d4-a716-446655440000",
            "email": "test@wealthwise.app",
            "aud": "wrong-audience" if wrong_audience else "authenticated",
            "role": "authenticated",
            "app_metadata": {},
            "user_metadata": {},
        }

        if not missing_claims:
            offset = -3600 if expired else 3600  # one hour in the past / future
            payload["exp"] = int(time.time()) + offset

        return jwt.encode(payload, secret, algorithm=algorithm)

    def test_validate_valid_token(self):
        """A well-formed, correctly signed token yields a TokenPayload."""
        token = self.create_test_token()

        payload = validate_supabase_token(token)

        assert isinstance(payload, TokenPayload)
        assert payload.sub == "550e8400-e29b-41d4-a716-446655440000"
        assert payload.email == "test@wealthwise.app"
        assert payload.aud == "authenticated"

    def test_validate_expired_token(self):
        """A token whose ``exp`` is in the past is rejected with 401."""
        token = self.create_test_token(expired=True)

        with pytest.raises(HTTPException) as exc_info:
            validate_supabase_token(token)

        assert exc_info.value.status_code == 401
        assert "Could not validate credentials" in exc_info.value.detail

    def test_validate_wrong_audience(self):
        """A token whose ``aud`` is not ``authenticated`` is rejected with 401."""
        token = self.create_test_token(wrong_audience=True)

        with pytest.raises(HTTPException) as exc_info:
            validate_supabase_token(token)

        assert exc_info.value.status_code == 401

    def test_validate_invalid_signature(self):
        """A token signed with a different secret is rejected with 401."""
        token = self.create_test_token(secret="wrong-secret")

        with pytest.raises(HTTPException) as exc_info:
            validate_supabase_token(token)

        assert exc_info.value.status_code == 401

    def test_validate_missing_exp(self):
        """A token with no ``exp`` claim is rejected with 401."""
        token = self.create_test_token(missing_claims=True)

        with pytest.raises(HTTPException) as exc_info:
            validate_supabase_token(token)

        assert exc_info.value.status_code == 401
|
|
|
|
|
|
class TestGetCurrentUser:
    """Test cases for the async ``get_current_user`` dependency."""

    def test_get_current_user_valid_token(self):
        """A valid token resolves to a populated UserContext."""
        payload = {
            "sub": "550e8400-e29b-41d4-a716-446655440000",
            "email": "user@wealthwise.app",
            "aud": "authenticated",
            "exp": int(time.time()) + 3600,  # valid for one hour
            "role": "authenticated",
            "app_metadata": {},
            "user_metadata": {},
        }
        token = jwt.encode(
            payload,
            settings.SUPABASE_JWT_SECRET or settings.SECRET_KEY,
            algorithm="HS256",
        )

        # get_current_user is async: calling it returns a coroutine, which
        # asyncio.run drives to completion on a fresh event loop.
        user = asyncio.run(get_current_user(token))

        assert isinstance(user, UserContext)
        assert user.id == "550e8400-e29b-41d4-a716-446655440000"
        assert user.email == "user@wealthwise.app"
        assert user.role == "authenticated"

    def test_get_current_user_no_token(self):
        """A missing token (``None``) raises a 401 HTTPException."""
        with pytest.raises(HTTPException) as exc_info:
            asyncio.run(get_current_user(None))

        assert exc_info.value.status_code == 401

    def test_get_current_user_invalid_token(self):
        """A malformed token string raises a 401 HTTPException."""
        with pytest.raises(HTTPException) as exc_info:
            asyncio.run(get_current_user("invalid-token"))

        assert exc_info.value.status_code == 401
|
|
|
|
|
|
class TestUserContext:
    """Tests for building a UserContext from a decoded TokenPayload."""

    @staticmethod
    def _build_payload(app_metadata: dict) -> TokenPayload:
        """Return a TokenPayload for the fixed test user with *app_metadata*."""
        return TokenPayload(
            sub="550e8400-e29b-41d4-a716-446655440000",
            email="test@wealthwise.app",
            aud="authenticated",
            role="authenticated",
            app_metadata=app_metadata,
        )

    def test_from_token_payload(self):
        """A role inside app_metadata wins over the top-level role claim."""
        ctx = UserContext.from_token_payload(self._build_payload({"role": "admin"}))

        assert ctx.id == "550e8400-e29b-41d4-a716-446655440000"
        assert ctx.email == "test@wealthwise.app"
        assert ctx.role == "admin"  # sourced from app_metadata

    def test_from_token_payload_no_app_metadata_role(self):
        """Without an app_metadata role, the top-level role claim is used."""
        ctx = UserContext.from_token_payload(self._build_payload({}))

        assert ctx.role == "authenticated"  # sourced from the role claim
|
|
|
|
|
|
class TestProtectedEndpoints:
    """End-to-end checks that protected HTTP routes enforce authentication."""

    def test_me_endpoint_without_auth(self):
        """Calling /me with no Authorization header is rejected with 401."""
        response = TestClient(app).get("/api/v1/users/me")

        assert response.status_code == 401

    def test_me_endpoint_with_valid_auth(self):
        """Calling /me with a valid bearer token returns the caller's identity."""
        user_id = "550e8400-e29b-41d4-a716-446655440000"
        user_email = "user@wealthwise.app"

        claims = {
            "sub": user_id,
            "email": user_email,
            "aud": "authenticated",
            "exp": int(time.time()) + 3600,  # valid for one hour
            "role": "authenticated",
            "app_metadata": {},
            "user_metadata": {},
        }
        signing_key = settings.SUPABASE_JWT_SECRET or settings.SECRET_KEY
        bearer = jwt.encode(claims, signing_key, algorithm="HS256")

        response = TestClient(app).get(
            "/api/v1/users/me",
            headers={"Authorization": f"Bearer {bearer}"},
        )

        assert response.status_code == 200
        body = response.json()
        assert body["id"] == user_id
        assert body["email"] == user_email