"""Template FastAPI server module."""
import json
import logging
import os
import sys
from abc import ABC, abstractmethod
from collections.abc import AsyncGenerator, Callable
from contextlib import asynccontextmanager
from importlib.metadata import metadata
from pathlib import Path
from typing import Any

import dotenv
import uvicorn
from fastapi import FastAPI, HTTPException, Request, Response, Security
from fastapi.exception_handlers import http_exception_handler
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from fastapi.security import APIKeyHeader
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from pydantic_core import ValidationError
from slowapi import Limiter
from slowapi.errors import RateLimitExceeded
from slowapi.util import get_remote_address
from starlette.exceptions import HTTPException as StarletteHTTPException
from template_python.logging_setup import add_file_handler, setup_default_logging

from python_template_server.authentication_handler import verify_token
from python_template_server.certificate_handler import CertificateHandler
from python_template_server.constants import (
    API_KEY_HEADER_NAME,
    API_PREFIX,
    CONFIG_FILE_PATH,
    ENV_FILE_PATH,
    LOGGING_BACKUP_COUNT,
    LOGGING_FILE_PATH,
    LOGGING_MAX_BYTES_MB,
    MB_TO_BYTES,
    STATIC_DIR,
)
from python_template_server.middleware import RequestLoggingMiddleware, SecurityHeadersMiddleware
from python_template_server.models import (
    CustomJSONResponse,
    GetHealthResponse,
    GetLoginResponse,
    ResponseCode,
    TemplateServerConfig,
)
# Logging is configured at import time, before any server is constructed.
# Apply the project's default logging setup (from template_python.logging_setup).
setup_default_logging()
# Also log to LOGGING_FILE_PATH via the shared helper.
add_file_handler(
    logging_filepath=LOGGING_FILE_PATH,
    # The size limit is configured in MB; convert it for the `max_bytes` keyword.
    max_bytes=LOGGING_MAX_BYTES_MB * MB_TO_BYTES,
    backup_count=LOGGING_BACKUP_COUNT,
)
# Module logger used throughout this file.
logger = logging.getLogger(__name__)
class TemplateServer(ABC):
    """Template FastAPI server.

    This class provides a template for building FastAPI servers with common features
    such as request logging, security headers, CORS, rate limiting and optional static file serving.

    Ensure you implement the `setup_routes` and `validate_config` methods in subclasses.
    """

    def __init__(
        self,
        package_name: str = "python-template-server",
        api_prefix: str = API_PREFIX,
        api_key_header_name: str = API_KEY_HEADER_NAME,
        config_filepath: Path = CONFIG_FILE_PATH,
        config: TemplateServerConfig | None = None,
        static_dir: Path = STATIC_DIR,
    ) -> None:
        """Initialize the TemplateServer.

        Loads environment variables from ``ENV_FILE_PATH``, loads (or accepts) the configuration,
        creates the FastAPI application and registers middleware and routes.

        :param str package_name: Installed distribution whose metadata (Name, Summary, Version)
            is used as the FastAPI title, description and version
        :param str api_prefix: The API prefix for the server (passed to FastAPI as ``root_path``)
        :param str api_key_header_name: Name of the request header that carries the API key
        :param Path config_filepath: Path to the configuration file (only read when ``config`` is None)
        :param TemplateServerConfig | None config: Optional pre-loaded configuration
        :param Path static_dir: Directory mounted at "/" when it exists and contains ``index.html``
        """
        dotenv.load_dotenv(ENV_FILE_PATH)
        self.api_prefix = api_prefix
        self.api_key_header_name = api_key_header_name
        self.config_filepath = config_filepath
        self.config = config if config is not None else self.load_config(self.config_filepath)
        self.cert_handler = CertificateHandler(self.config.certificate)
        self.static_dir = static_dir
        CustomJSONResponse.configure(self.config.json_response)

        self.package_metadata = metadata(package_name)
        self.app = FastAPI(
            title=self.package_metadata["Name"],
            description=self.package_metadata["Summary"],
            version=self.package_metadata["Version"],
            root_path=self.api_prefix,
            lifespan=self.lifespan,
            default_response_class=CustomJSONResponse,
        )
        # Header extractor used by every authenticated route (see `_build_api_key_dependency`).
        self.api_key_header = APIKeyHeader(name=self.api_key_header_name, auto_error=False)

        # Runtime settings come from the environment (populated from ENV_FILE_PATH above).
        self.host = os.getenv("HOST", "localhost")
        self.port = int(os.getenv("PORT", "443"))
        self.hashed_token = os.getenv("API_TOKEN_HASH", "")

        self._setup_request_logging()
        self._setup_security_headers()
        self._setup_cors()
        self._setup_rate_limiting()
        self._setup_routes()

    @property
    def static_dir_exists(self) -> bool:
        """Check whether the static directory can be served.

        :return bool: True if the static directory exists and contains an ``index.html`` file, False otherwise
        """
        return self.static_dir.exists() and (self.static_dir / "index.html").exists()

    @staticmethod
    @asynccontextmanager
    async def lifespan(app: FastAPI) -> AsyncGenerator[None]:
        """Handle application lifespan events.

        The base implementation performs no startup or shutdown work.

        :param FastAPI app: The application whose lifespan is being managed
        """
        yield

    @abstractmethod
    def validate_config(self, config_data: dict[str, Any]) -> TemplateServerConfig:
        """Validate configuration data against the TemplateServerConfig model.

        Subclasses must override this; the base implementation can be reused via ``super()``.

        :param dict config_data: The configuration data to validate
        :return TemplateServerConfig: The validated configuration model
        :raise ValidationError: If the configuration data is invalid
        """
        return TemplateServerConfig.model_validate(config_data)

    def load_config(self, config_filepath: Path) -> TemplateServerConfig:
        """Load configuration from the specified json file.

        The validated configuration is written back to the same file before being returned.

        :param Path config_filepath: Path to the configuration file
        :return TemplateServerConfig: The validated configuration model
        :raise SystemExit: If configuration file is missing, invalid JSON, or fails validation
        """
        if not config_filepath.exists():
            logger.error("Configuration file not found: %s", config_filepath)
            sys.exit(1)

        try:
            with config_filepath.open() as f:
                config_data = json.load(f)
            config = self.validate_config(config_data)
            # NOTE(review): presumably persists normalized values/defaults; a failed write lands
            # in the OSError branch below as well.
            config.save_to_file(config_filepath)
        except json.JSONDecodeError:
            logger.exception("JSON parsing error: %s", config_filepath)
            sys.exit(1)
        except OSError:
            logger.exception("JSON read error: %s", config_filepath)
            sys.exit(1)
        except ValidationError:
            logger.exception("Invalid configuration in: %s", config_filepath)
            sys.exit(1)
        else:
            return config

    async def _verify_api_key(
        self, api_key: str | None = Security(APIKeyHeader(name=API_KEY_HEADER_NAME, auto_error=False))
    ) -> None:
        """Verify an API key taken from the request header.

        Routes registered through `add_authenticated_route` supply the value read from the header
        configured by ``api_key_header_name``. The default below applies only when this method is
        used directly as a dependency.

        :param str | None api_key: The API key from the request header, or None if the header was absent
        :raise HTTPException: 401 if the API key is missing, invalid, or cannot be verified
        """
        if api_key is None:
            logger.warning("Missing API key in request!")
            raise HTTPException(
                status_code=ResponseCode.UNAUTHORIZED,
                detail="Missing API key",
            )

        try:
            is_valid = verify_token(api_key, self.hashed_token)
        except ValueError as e:
            logger.exception("Error verifying API key!")
            raise HTTPException(
                status_code=ResponseCode.UNAUTHORIZED,
                detail=str(e),
            ) from e

        if not is_valid:
            logger.warning("Invalid API key attempt!")
            raise HTTPException(
                status_code=ResponseCode.UNAUTHORIZED,
                detail="Invalid API key",
            )

        logger.debug("API key validated successfully.")

    def _build_api_key_dependency(self) -> Callable[..., Any]:
        """Create a FastAPI dependency that reads the API key from the configured header.

        FastAPI resolves dependency parameters from the function signature. A method's default
        values are fixed when the class is defined, so the instance's header (``self.api_key_header``,
        built from ``api_key_header_name``) is bound here through a closure instead.

        :return Callable: An async dependency that extracts the header value and delegates to `_verify_api_key`
        """
        api_key_header = self.api_key_header

        async def verify_api_key_header(api_key: str | None = Security(api_key_header)) -> None:
            await self._verify_api_key(api_key)

        return verify_api_key_header

    def _setup_request_logging(self) -> None:
        """Set up request logging middleware."""
        self.app.add_middleware(RequestLoggingMiddleware)
        logger.info("Request logging enabled")

    def _setup_security_headers(self) -> None:
        """Set up security headers middleware using the HSTS and CSP settings from the configuration."""
        self.app.add_middleware(
            SecurityHeadersMiddleware,
            hsts_max_age=self.config.security.hsts_max_age,
            csp=self.config.security.content_security_policy,
        )
        logger.info(
            "Security headers enabled: HSTS max-age=%s, CSP=%s",
            self.config.security.hsts_max_age,
            self.config.security.content_security_policy,
        )

    def _setup_cors(self) -> None:
        """Set up CORS middleware if it is enabled in the configuration."""
        if not self.config.cors.enabled:
            logger.info("CORS is disabled")
            return

        self.app.add_middleware(
            CORSMiddleware,
            allow_origins=self.config.cors.allow_origins,
            allow_credentials=self.config.cors.allow_credentials,
            allow_methods=self.config.cors.allow_methods,
            allow_headers=self.config.cors.allow_headers,
            expose_headers=self.config.cors.expose_headers,
            max_age=self.config.cors.max_age,
        )
        logger.info(
            "CORS enabled: origins=%s, credentials=%s, methods=%s, headers=%s",
            self.config.cors.allow_origins,
            self.config.cors.allow_credentials,
            self.config.cors.allow_methods,
            self.config.cors.allow_headers,
        )

    async def _rate_limit_exception_handler(self, request: Request, exc: RateLimitExceeded) -> CustomJSONResponse:
        """Handle rate limit exceeded exceptions.

        :param Request request: The incoming HTTP request
        :param RateLimitExceeded exc: The rate limit exceeded exception
        :return CustomJSONResponse: HTTP 429 JSON response, with Retry-After when the exception provides it
        """
        logger.warning("Rate limit exceeded for %s", request.url.path)
        return CustomJSONResponse(
            status_code=429,
            content={"detail": "Rate limit exceeded"},
            headers={"Retry-After": str(exc.retry_after)} if hasattr(exc, "retry_after") else {},
        )

    def _setup_rate_limiting(self) -> None:
        """Set up the rate limiter if it is enabled in the configuration.

        Sets ``self.limiter`` to None when rate limiting is disabled.
        """
        if not self.config.rate_limit.enabled:
            logger.info("Rate limiting is disabled")
            self.limiter = None
            return

        self.limiter = Limiter(
            key_func=get_remote_address,
            storage_uri=self.config.rate_limit.storage_uri,
        )
        self.app.state.limiter = self.limiter
        self.app.add_exception_handler(RateLimitExceeded, self._rate_limit_exception_handler)  # type: ignore[arg-type]
        logger.info(
            "Rate limiting enabled: rate=%s, storage=%s",
            self.config.rate_limit.rate_limit,
            self.config.rate_limit.storage_uri or "in-memory",
        )

    def _limit_route(self, route_function: Callable[..., Any]) -> Callable[..., Any]:
        """Apply rate limiting to a route function if enabled.

        :param Callable route_function: The route handler function
        :return Callable: The rate-limited handler, or the original handler when rate limiting is disabled
        """
        if self.limiter is not None:
            return self.limiter.limit(self.config.rate_limit.rate_limit)(route_function)  # type: ignore[no-any-return]
        return route_function

    def run(self) -> None:
        """Run the server over HTTPS using uvicorn.

        If either the SSL certificate or key file is missing, a self-signed certificate is generated first.
        Any exception raised while starting or serving is logged and the process exits with status 1.
        """
        try:
            cert_file = self.config.certificate.ssl_cert_file_path
            key_file = self.config.certificate.ssl_key_file_path
            if not (cert_file.exists() and key_file.exists()):
                logger.warning("SSL certificate or key file not found, generating self-signed certificate...")
                self.cert_handler.generate_self_signed_cert()

            logger.info("Starting server: https://%s:%s%s", self.host, self.port, self.api_prefix)
            uvicorn.run(
                app=self.app,
                host=self.host,
                port=self.port,
                ssl_keyfile=str(key_file),
                ssl_certfile=str(cert_file),
                # Uvicorn's own logs are quietened; request logging presumably comes from RequestLoggingMiddleware.
                log_level="warning",
                access_log=False,
            )
            logger.info("Server stopped.")
        except Exception:
            # Top-level boundary: log the traceback and exit non-zero.
            logger.exception("Failed to start!")
            sys.exit(1)

    def add_unauthenticated_route(
        self,
        endpoint: str,
        handler_function: Callable,
        response_model: type[BaseModel] | None,
        methods: list[str],
        limited: bool = True,  # noqa: FBT001, FBT002
    ) -> None:
        """Add an API route that does not require an API key.

        :param str endpoint: The API endpoint path
        :param Callable handler_function: The handler function for the endpoint
        :param type[BaseModel] | None response_model: The Pydantic model for the response
        :param list[str] methods: The HTTP methods for the endpoint
        :param bool limited: Whether to apply rate limiting to this route (when rate limiting is enabled)
        """
        self.app.add_api_route(
            endpoint,
            self._limit_route(handler_function) if limited else handler_function,
            methods=methods,
            response_model=response_model,
        )

    def add_authenticated_route(
        self,
        endpoint: str,
        handler_function: Callable,
        response_model: type[BaseModel],
        methods: list[str],
        limited: bool = True,  # noqa: FBT001, FBT002
    ) -> None:
        """Add an API route protected by the API key header configured via ``api_key_header_name``.

        :param str endpoint: The API endpoint path
        :param Callable handler_function: The handler function for the endpoint
        :param type[BaseModel] response_model: The Pydantic model for the response
        :param list[str] methods: The HTTP methods for the endpoint
        :param bool limited: Whether to apply rate limiting to this route (when rate limiting is enabled)
        """
        self.app.add_api_route(
            endpoint,
            self._limit_route(handler_function) if limited else handler_function,
            methods=methods,
            response_model=response_model,
            dependencies=[Security(self._build_api_key_dependency())],
        )

    def _setup_routes(self) -> None:
        """Set up the built-in routes, subclass routes, static files and the HTTP exception handler."""
        self.add_unauthenticated_route(
            endpoint="/health",
            handler_function=self.get_health,
            response_model=GetHealthResponse,
            methods=["GET"],
            limited=False,
        )
        self.add_authenticated_route(
            endpoint="/login",
            handler_function=self.get_login,
            response_model=GetLoginResponse,
            methods=["GET"],
            limited=True,
        )
        self.setup_routes()

        # Mounted last so that the catch-all "/" mount does not shadow the API routes above.
        if self.static_dir_exists:
            logger.info("Mounting static directory: %s", self.static_dir)
            self.app.mount("/", StaticFiles(directory=str(self.static_dir), html=True), name="static")

        @self.app.exception_handler(StarletteHTTPException)
        async def custom_404_handler(request: Request, exc: StarletteHTTPException) -> Response:
            """Serve the static 404.html page for 404s when available; otherwise use FastAPI's default handling.

            Registering this handler replaces FastAPI's default HTTPException handler. Re-raising from inside
            an exception handler would surface as a 500, so every other HTTP error is delegated to
            FastAPI's built-in handler to keep its status code, detail and headers.
            """
            if exc.status_code == ResponseCode.NOT_FOUND and self.static_dir_exists:
                not_found_page = self.static_dir / "404.html"
                if not_found_page.is_file():
                    return FileResponse(not_found_page, status_code=ResponseCode.NOT_FOUND)
            return await http_exception_handler(request, exc)

    @abstractmethod
    def setup_routes(self) -> None:
        """Add custom API routes.

        This method must be implemented by subclasses to define API endpoints
        using `add_unauthenticated_route` and `add_authenticated_route`.
        """

    async def get_health(self, request: Request) -> GetHealthResponse:
        """Get server health.

        :param Request request: The incoming HTTP request (unused here; kept so the handler can be rate limited)
        :return GetHealthResponse: Health status response
        :raise HTTPException: 500 if the server token (API_TOKEN_HASH) is not configured
        """
        if not self.hashed_token:
            raise HTTPException(
                status_code=ResponseCode.INTERNAL_SERVER_ERROR,
                detail="Server token is not configured",
            )
        return GetHealthResponse(
            message="Server is healthy",
            timestamp=GetHealthResponse.current_timestamp(),
        )

    async def get_login(self, request: Request) -> GetLoginResponse:
        """Handle user login and return a success response.

        The API key itself is checked by the route's dependency before this handler runs.

        :param Request request: The incoming HTTP request (unused here; kept so the handler can be rate limited)
        :return GetLoginResponse: Login success response
        :raise HTTPException: 500 if the server token (API_TOKEN_HASH) is not configured
        """
        if not self.hashed_token:
            raise HTTPException(
                status_code=ResponseCode.INTERNAL_SERVER_ERROR,
                detail="Server token is not configured",
            )
        logger.info("User login successful.")
        return GetLoginResponse(
            message="Login successful.",
            timestamp=GetLoginResponse.current_timestamp(),
        )