# 🚀 Day 17: Advanced Middleware
## Custom Middleware
Creating custom middleware components:
python
from nexios import get_application
from nexios.http import Request, Response
from nexios.middleware import Middleware
from typing import Callable, Optional
import time
import uuid
# Application instance that every middleware in this file is registered on.
app = get_application()
# Timing middleware
class TimingMiddleware(Middleware):
    """Report how long the downstream handler took in an ``X-Process-Time`` header.

    The header value is the elapsed time in seconds, formatted with four
    decimal places, measured from just before ``call_next`` until it returns.
    """

    async def __call__(
        self,
        request: Request,
        call_next: Callable
    ) -> Response:
        # perf_counter() is monotonic and high-resolution; time.time() follows
        # the wall clock, which can be adjusted (e.g. by NTP) mid-request and
        # yield negative or inflated durations.
        start_time = time.perf_counter()
        response = await call_next(request)
        process_time = time.perf_counter() - start_time
        response.headers["X-Process-Time"] = f"{process_time:.4f}"
        return response
# Request ID middleware
class RequestIDMiddleware(Middleware):
    """Propagate or create a per-request correlation ID.

    If the incoming request carries a non-empty ``header_name`` header, its
    value is reused; otherwise a fresh UUID4 string is generated. The ID is
    stored on ``request.state.request_id`` for handlers and echoed back in the
    response under the same header name.

    Args:
        header_name: Header used both to read and to write the ID.
    """

    def __init__(self, header_name: str = "X-Request-ID"):
        self.header_name = header_name

    async def __call__(
        self,
        request: Request,
        call_next: Callable
    ) -> Response:
        # `or` instead of a .get() default: an empty header value is replaced
        # too, and uuid4() is only generated when an ID is actually needed.
        request_id = request.headers.get(self.header_name) or str(uuid.uuid4())
        # Expose the ID to downstream handlers.
        request.state.request_id = request_id
        response = await call_next(request)
        response.headers[self.header_name] = request_id
        return response
# Rate limiting middleware
class RateLimitMiddleware(Middleware):
    """In-memory sliding-window rate limiter keyed by client IP.

    A client may make at most ``requests_per_minute`` requests within any
    window of ``block_duration`` seconds. The parameter name assumes the
    default 60-second window. Requests over the limit receive a 429 response
    and are not recorded, so they do not extend the client's window.

    State lives in this instance's ``requests`` dict, so limits are per
    instance/process and are lost on restart.

    Args:
        requests_per_minute: Maximum requests allowed per window.
        block_duration: Window length in seconds.
    """

    def __init__(
        self,
        requests_per_minute: int = 60,
        block_duration: int = 60
    ):
        self.limit = requests_per_minute
        self.duration = block_duration
        # client IP -> timestamps (oldest first) of its requests still inside
        # the window. Pruned on every request, and a timestamp is only recorded
        # while a client is under the limit, so each list stays <= self.limit long.
        self.requests = {}

    async def __call__(
        self,
        request: Request,
        call_next: Callable
    ) -> Response:
        # NOTE(review): request.client may be None on some transports or test
        # clients. Such requests share the "unknown" bucket; confirm this is
        # acceptable for your deployment.
        client = request.client
        client_ip = client.host if client is not None else "unknown"
        # Monotonic clock: a wall-clock jump backwards would otherwise leave
        # "future" timestamps in the window and lock clients out.
        current_time = time.monotonic()
        self.clean_old_requests(current_time)
        if self.is_rate_limited(client_ip, current_time):
            return Response(
                content={"error": "Rate limit exceeded"},
                status_code=429
            )
        self.record_request(client_ip, current_time)
        return await call_next(request)

    def clean_old_requests(self, current_time: float):
        """Drop timestamps outside the window, then drop clients with none left.

        Trimming each list (not just removing idle clients) keeps memory
        bounded for clients that stay continuously active.
        """
        cutoff = current_time - self.duration
        pruned = {}
        for ip, times in self.requests.items():
            recent = [t for t in times if t > cutoff]
            if recent:
                pruned[ip] = recent
        self.requests = pruned

    def is_rate_limited(self, ip: str, current_time: float) -> bool:
        """Return True if ``ip`` already has ``limit`` requests inside the window."""
        cutoff = current_time - self.duration
        recent_count = sum(1 for t in self.requests.get(ip, ()) if t > cutoff)
        return recent_count >= self.limit

    def record_request(self, ip: str, time: float):
        """Append a request timestamp for ``ip``.

        The parameter is named ``time`` (shadowing the module inside this
        method only) to keep the original keyword interface.
        """
        self.requests.setdefault(ip, []).append(time)
# Add middleware to app
# NOTE(review): the order of these calls determines middleware execution
# order. Whether the first-registered middleware ends up outermost or innermost
# depends on nexios' add_middleware semantics — confirm, e.g. to know whether
# 429 responses from the rate limiter still get X-Process-Time/X-Request-ID.
app.add_middleware(TimingMiddleware)
app.add_middleware(RequestIDMiddleware)
# At most 30 requests per client IP per window (block_duration defaults to 60 s).
app.add_middleware(RateLimitMiddleware, requests_per_minute=30)
## Middleware Chains
Managing middleware execution order:
python
# Middleware is needed here too: it appears in create_chain's signature
# annotation, which is evaluated when the method is defined.
from nexios.middleware import Middleware, MiddlewareChain
from typing import Dict, List, Optional, Type


class MiddlewareManager:
    """Registry of named middleware chains.

    Chains are built from lists of middleware classes and looked up by name,
    so routers can share a chain definition.
    """

    def __init__(self):
        # chain name -> MiddlewareChain built by create_chain()
        self.chains: Dict[str, MiddlewareChain] = {}

    def create_chain(
        self,
        name: str,
        middlewares: List[Type[Middleware]]
    ) -> MiddlewareChain:
        """Build a chain from ``middlewares`` (in list order) and register it.

        An existing chain with the same ``name`` is replaced.

        Returns:
            The newly created chain. Callers that ignore the return value are
            unaffected.
        """
        chain = MiddlewareChain()
        for middleware in middlewares:
            chain.add(middleware)
        self.chains[name] = chain
        return chain

    def get_chain(self, name: str) -> Optional[MiddlewareChain]:
        """Return the chain registered under ``name``, or None if there is none."""
        return self.chains.get(name)
# NOTE(review): `Router` was used below without being imported anywhere in
# this snippet (NameError). nexios presumably exposes it from `nexios.routing`
# — confirm against the installed version.
from nexios.routing import Router

# Initialize manager
middleware_manager = MiddlewareManager()

# Create middleware chains: "api" adds rate limiting on top of the timing and
# request-ID middleware that "web" uses.
middleware_manager.create_chain(
    "api",
    [
        TimingMiddleware,
        RequestIDMiddleware,
        RateLimitMiddleware,
    ],
)
middleware_manager.create_chain(
    "web",
    [
        TimingMiddleware,
        RequestIDMiddleware,
    ],
)

# Apply middleware chain to router.
# NOTE(review): assigning a chain to `Router.middleware` is assumed to be how
# per-router middleware is attached — confirm against the nexios Router API.
# If the app-level RateLimitMiddleware is also active, /api requests may be
# counted by two separate limiter instances.
api_router = Router(prefix="/api")
api_chain = middleware_manager.get_chain("api")
api_router.middleware = api_chain

web_router = Router(prefix="/web")
web_chain = middleware_manager.get_chain("web")
web_router.middleware = web_chain

# Include routers
app.include_router(api_router)
app.include_router(web_router)
## Global Middleware
Implementing application-wide middleware:
python
from nexios.security import SecurityHeaders
from nexios.cors import CORSMiddleware
from nexios.compression import CompressionMiddleware
# Security headers middleware
class SecurityHeadersMiddleware(Middleware):
    """Add security-related headers to every response.

    All keyword options are forwarded unchanged to ``SecurityHeaders``. After
    the downstream handler returns, the mapping produced by
    ``SecurityHeaders.get_headers()`` is merged in via ``response.headers.update``.
    """

    def __init__(self, **options):
        self.headers = SecurityHeaders(**options)

    async def __call__(self, request: Request, call_next: Callable) -> Response:
        outgoing = await call_next(request)
        outgoing.headers.update(self.headers.get_headers())
        return outgoing
# Add global middleware
# NOTE(review): as with any add_middleware sequence, the registration order
# below determines execution order in a framework-specific way — confirm for nexios.
app.add_middleware(
    CORSMiddleware,
    # Presumably: only this origin may make cross-origin browser requests,
    # with any method and any request header allowed.
    allow_origins=["http://localhost:3000"],
    allow_methods=["*"],
    allow_headers=["*"]
)
# Options are forwarded as keyword arguments to SecurityHeaders.
app.add_middleware(
    SecurityHeadersMiddleware,
    xss_protection=True,
    content_type_options=True,
    frame_options="DENY",
    hsts=True
)
app.add_middleware(
    CompressionMiddleware,
    # Presumably: responses smaller than this (bytes?) are not compressed,
    # and 6 is a zlib/gzip-style level — confirm against CompressionMiddleware.
    minimum_size=1000,
    compression_level=6
)
# Order-specific middleware
@app.middleware("http", order=1)
async def first_middleware(request: Request, call_next):
    """Pass-through HTTP middleware registered with ``order=1``.

    It does no work of its own: it forwards the request and returns the
    downstream response unchanged. It presumably runs before ``order=2``
    middleware — confirm against nexios' ``app.middleware`` semantics.
    """
    downstream_response = await call_next(request)
    return downstream_response
@app.middleware("http", order=2)
async def second_middleware(request: Request, call_next):
    """Pass-through HTTP middleware registered with ``order=2``.

    It forwards the request and returns the downstream response unchanged. It
    presumably runs after ``order=1`` middleware — confirm against nexios'
    ``app.middleware`` semantics.
    """
    downstream_response = await call_next(request)
    return downstream_response
## Context Management
Managing request context in middleware:
python
from contextvars import ContextVar
from typing import Optional, Any
# Context variables
# Every variable defaults to None. Reading one outside a request (or before a
# middleware has set it) then yields None instead of raising LookupError.
request_id: ContextVar[Optional[str]] = ContextVar("request_id", default=None)
current_user: ContextVar[Optional[dict]] = ContextVar(
    "current_user",
    default=None
)
trace_id: ContextVar[Optional[str]] = ContextVar("trace_id", default=None)
class ContextMiddleware(Middleware):
    """Populate the ``request_id`` and ``trace_id`` context variables per request.

    ``request_id`` always gets a fresh UUID4. ``trace_id`` is taken from a
    non-empty ``X-Trace-ID`` request header or, failing that, a fresh UUID4.
    When the downstream handler finishes, successfully or not, both variables
    are restored to the values they had before this request.
    """

    async def __call__(
        self,
        request: Request,
        call_next: Callable
    ) -> Response:
        # Keep the tokens returned by set() so cleanup can reset() to the exact
        # previous value. Setting None instead would clobber any value an outer
        # scope had set and store a None in a variable declared for str.
        request_id_token = request_id.set(str(uuid.uuid4()))
        trace_id_token = trace_id.set(
            request.headers.get("X-Trace-ID") or str(uuid.uuid4())
        )
        try:
            return await call_next(request)
        finally:
            # Reset in reverse order of setting.
            trace_id.reset(trace_id_token)
            request_id.reset(request_id_token)
# Context manager for database transactions
class TransactionContext:
    """Async context manager that wraps one database transaction.

    On entry it awaits ``db.transaction()`` and yields the resulting
    transaction object. On exit it awaits ``commit()`` if the block finished
    without an exception, and ``rollback()`` otherwise. Exceptions are never
    suppressed.
    """

    def __init__(self, db):
        self.db = db

    async def __aenter__(self):
        tx = await self.db.transaction()
        self.transaction = tx
        return tx

    async def __aexit__(self, exc_type, exc, tb):
        finish = (
            self.transaction.rollback
            if exc_type is not None
            else self.transaction.commit
        )
        await finish()
# Transaction middleware
class TransactionMiddleware(Middleware):
    """Run the downstream handler inside a ``TransactionContext`` on ``db``.

    The transaction commits when ``call_next`` returns normally and rolls back
    when it raises.

    NOTE(review): a handler that *returns* an error response (e.g. a 500)
    instead of raising still commits — confirm that is intended.
    """

    def __init__(self, db):
        self.db = db

    async def __call__(self, request: Request, call_next: Callable) -> Response:
        transaction_scope = TransactionContext(self.db)
        async with transaction_scope:
            response = await call_next(request)
        return response
# Context-aware logging
class LogContext:
    """Snapshot of the per-request context variables, for structured logging.

    Values are read once, when the instance is created. Later changes to the
    context variables are not reflected in ``get_context()``.
    """

    def __init__(self):
        # .get(None): request_id/trace_id may be declared without a default,
        # and reading an unset ContextVar without a default raises LookupError,
        # e.g. when logging runs outside a request or before ContextMiddleware.
        self.request_id = request_id.get(None)
        self.trace_id = trace_id.get(None)
        self.user = current_user.get()

    def get_context(self) -> dict:
        """Return the snapshot as a dict; ``user_id`` is None when no user is set."""
        return {
            "request_id": self.request_id,
            "trace_id": self.trace_id,
            "user_id": self.user.get("id") if self.user else None
        }
# Logging middleware with context
class ContextualLoggingMiddleware(Middleware):
    """Log request start, completion and failure with per-request context.

    Every log call passes ``extra`` fields built from a ``LogContext`` created
    once at the start of the request, plus event-specific fields:
    ``method``/``path`` on start, ``status_code`` on completion, and ``error``
    on failure. Exceptions are logged with ``exc_info=True`` and re-raised.
    """

    def __init__(self, logger):
        self.logger = logger

    async def __call__(
        self,
        request: Request,
        call_next: Callable
    ) -> Response:
        snapshot = LogContext()

        def with_context(**fields):
            # Context keys first, event-specific keys after (later keys win).
            return {**snapshot.get_context(), **fields}

        self.logger.info(
            "Request started",
            extra=with_context(method=request.method, path=request.url.path),
        )
        try:
            response = await call_next(request)
            # Kept inside the try: a failure here is also reported as "Request failed".
            self.logger.info(
                "Request completed",
                extra=with_context(status_code=response.status_code),
            )
            return response
        except Exception as exc:
            self.logger.error(
                "Request failed",
                extra=with_context(error=str(exc)),
                exc_info=True,
            )
            raise
## 📝 Practice Exercise
Create custom middleware:
- Authentication
- Caching
- Request validation
- Response transformation
Implement middleware chains:
- Route-specific chains
- Conditional middleware
- Dynamic chain building
- Chain ordering
Build context management:
- Request tracking
- User context
- Resource cleanup
- Error handling
## 📚 Additional Resources
## 🎯 Next Steps
Tomorrow in Day 18: Advanced Routing, we'll explore:
- Route groups
- Dynamic routes
- Route dependencies
- URL generation