Files
WealthWise/backend/tests/test_auth.py

246 lines
8.0 KiB
Python
Raw Normal View History

"""Tests for authentication layer.
This module tests JWT validation and authentication dependencies.
"""
import asyncio
import time
from datetime import datetime, timedelta, timezone
from typing import Optional

import pytest
from fastapi import HTTPException
from fastapi.testclient import TestClient
from jose import jwt

from app.api.deps import get_current_user, validate_supabase_token
from app.core.config import get_settings
from app.main import app
from app.schemas.token import TokenPayload, UserContext
# Application settings, loaded once at import time. The tests below read the
# JWT signing secrets (SUPABASE_JWT_SECRET, falling back to SECRET_KEY) from it.
settings = get_settings()
class TestJWTValidation:
    """Test cases for JWT token validation via ``validate_supabase_token``."""

    def create_test_token(
        self,
        secret: Optional[str] = None,
        expired: bool = False,
        wrong_audience: bool = False,
        missing_claims: bool = False,
        algorithm: str = "HS256",
    ) -> str:
        """Create a test JWT token with specified properties.

        Args:
            secret: JWT signing secret. When ``None`` (the default), falls
                back to ``settings.SUPABASE_JWT_SECRET`` and then to
                ``settings.SECRET_KEY``.
            expired: Whether the ``exp`` claim should lie one hour in the
                past instead of one hour in the future.
            wrong_audience: Whether to set ``aud`` to ``"wrong-audience"``
                instead of ``"authenticated"``.
            missing_claims: Whether to omit the ``exp`` claim entirely.
                When true, ``expired`` has no effect.
            algorithm: Signing algorithm passed to ``jwt.encode``.

        Returns:
            Encoded JWT string.
        """
        if secret is None:
            # Default to the configured key; assumed to be the key the
            # validator verifies against (the valid-token test relies on it).
            secret = settings.SUPABASE_JWT_SECRET or settings.SECRET_KEY

        # Base payload shared by every variant.
        payload = {
            "sub": "550e8400-e29b-41d4-a716-446655440000",
            "email": "test@wealthwise.app",
            "aud": "wrong-audience" if wrong_audience else "authenticated",
            "role": "authenticated",
            "app_metadata": {},
            "user_metadata": {},
        }
        if not missing_claims:
            # Expire one hour ago (expired) or one hour from now (valid).
            offset = -3600 if expired else 3600
            payload["exp"] = int(time.time()) + offset
        return jwt.encode(payload, secret, algorithm=algorithm)

    def test_validate_valid_token(self):
        """Test that a well-formed, correctly signed token is accepted."""
        token = self.create_test_token()
        payload = validate_supabase_token(token)
        assert isinstance(payload, TokenPayload)
        assert payload.sub == "550e8400-e29b-41d4-a716-446655440000"
        assert payload.email == "test@wealthwise.app"
        assert payload.aud == "authenticated"

    def test_validate_expired_token(self):
        """Test that expired tokens are rejected with a 401."""
        token = self.create_test_token(expired=True)
        with pytest.raises(HTTPException) as exc_info:
            validate_supabase_token(token)
        assert exc_info.value.status_code == 401
        assert "Could not validate credentials" in exc_info.value.detail

    def test_validate_wrong_audience(self):
        """Test that tokens with the wrong audience are rejected with a 401."""
        token = self.create_test_token(wrong_audience=True)
        with pytest.raises(HTTPException) as exc_info:
            validate_supabase_token(token)
        assert exc_info.value.status_code == 401

    def test_validate_invalid_signature(self):
        """Test that tokens signed with the wrong secret are rejected."""
        token = self.create_test_token(secret="wrong-secret")
        with pytest.raises(HTTPException) as exc_info:
            validate_supabase_token(token)
        assert exc_info.value.status_code == 401

    def test_validate_missing_exp(self):
        """Test that tokens without an ``exp`` claim are rejected."""
        token = self.create_test_token(missing_claims=True)
        with pytest.raises(HTTPException) as exc_info:
            validate_supabase_token(token)
        assert exc_info.value.status_code == 401
class TestGetCurrentUser:
    """Test cases for the async ``get_current_user`` dependency.

    ``get_current_user`` is awaited in the original tests, so each test drives
    it with ``asyncio.run``. That call creates a new event loop, runs the
    coroutine to completion and closes the loop.
    """

    def test_get_current_user_valid_token(self):
        """Test that a valid token returns a populated UserContext."""
        payload = {
            "sub": "550e8400-e29b-41d4-a716-446655440000",
            "email": "user@wealthwise.app",
            "aud": "authenticated",
            "exp": int(time.time()) + 3600,
            "role": "authenticated",
            "app_metadata": {},
            "user_metadata": {},
        }
        token = jwt.encode(
            payload,
            settings.SUPABASE_JWT_SECRET or settings.SECRET_KEY,
            algorithm="HS256",
        )

        user = asyncio.run(get_current_user(token))

        assert isinstance(user, UserContext)
        assert user.id == "550e8400-e29b-41d4-a716-446655440000"
        assert user.email == "user@wealthwise.app"
        assert user.role == "authenticated"

    def test_get_current_user_no_token(self):
        """Test that a missing token raises a 401."""
        with pytest.raises(HTTPException) as exc_info:
            asyncio.run(get_current_user(None))
        assert exc_info.value.status_code == 401

    def test_get_current_user_invalid_token(self):
        """Test that a malformed token raises a 401."""
        with pytest.raises(HTTPException) as exc_info:
            asyncio.run(get_current_user("invalid-token"))
        assert exc_info.value.status_code == 401
class TestUserContext:
    """Checks for ``UserContext.from_token_payload`` role resolution."""

    USER_ID = "550e8400-e29b-41d4-a716-446655440000"

    def _make_payload(self, app_metadata):
        """Build a TokenPayload for the fixed test user with ``app_metadata``."""
        return TokenPayload(
            sub=self.USER_ID,
            email="test@wealthwise.app",
            aud="authenticated",
            role="authenticated",
            app_metadata=app_metadata,
        )

    def test_from_token_payload(self):
        """Test that a role in app_metadata is used for UserContext.role."""
        context = UserContext.from_token_payload(
            self._make_payload({"role": "admin"})
        )

        assert context.id == self.USER_ID
        assert context.email == "test@wealthwise.app"
        # Expected to come from app_metadata, not the top-level role claim.
        assert context.role == "admin"

    def test_from_token_payload_no_app_metadata_role(self):
        """Test the fallback to the role claim when app_metadata has no role."""
        context = UserContext.from_token_payload(self._make_payload({}))

        # Expected to come from the top-level role claim.
        assert context.role == "authenticated"
class TestProtectedEndpoints:
    """Integration tests: the ``/me`` endpoint must enforce authentication."""

    ME_PATH = "/api/v1/users/me"

    def test_me_endpoint_without_auth(self):
        """Test that a request without credentials is answered with 401."""
        response = TestClient(app).get(self.ME_PATH)
        assert response.status_code == 401

    def test_me_endpoint_with_valid_auth(self):
        """Test that a valid bearer token yields the caller's identity."""
        user_id = "550e8400-e29b-41d4-a716-446655440000"
        email = "user@wealthwise.app"
        claims = {
            "sub": user_id,
            "email": email,
            "aud": "authenticated",
            "exp": int(time.time()) + 3600,
            "role": "authenticated",
            "app_metadata": {},
            "user_metadata": {},
        }
        signing_key = settings.SUPABASE_JWT_SECRET or settings.SECRET_KEY
        token = jwt.encode(claims, signing_key, algorithm="HS256")
        auth_headers = {"Authorization": f"Bearer {token}"}

        response = TestClient(app).get(self.ME_PATH, headers=auth_headers)

        assert response.status_code == 200
        body = response.json()
        assert body["id"] == user_id
        assert body["email"] == email