"""
Server-Sent Events (SSE) service for real-time status updates.

This module implements an event broadcaster pattern using asyncio.Queue per client.
When receipt status changes occur (CRUD operations), events are pushed to all
connected clients who are listening for that specific batch or all receipts.

Usage:
    # In router endpoint (SSE stream):
    async for event in sse_service.subscribe(batch_id=None):
        yield event

    # When status changes (from CRUD operations):
    await sse_service.broadcast_status_change(receipt_id, status, processing_status, batch_id)
"""

import asyncio
import json
import logging
from dataclasses import dataclass, asdict
from typing import AsyncGenerator, Optional
from datetime import datetime, timezone

logger = logging.getLogger(__name__)

# Default cap on the number of undelivered events buffered per client. Once a
# client's queue holds this many events, further events for it are dropped.
DEFAULT_MAX_QUEUE_SIZE = 1000

# Default number of seconds a subscription waits for an event before emitting
# a keep-alive comment.
DEFAULT_KEEPALIVE_INTERVAL = 30.0


@dataclass
class StatusChangeEvent:
    """Event data for receipt status changes.

    If ``timestamp`` is not supplied it is filled in with the current UTC time
    as a naive ISO-8601 string (no UTC offset suffix).
    """

    receipt_id: int
    status: str
    processing_status: Optional[str] = None
    batch_id: Optional[str] = None
    timestamp: Optional[str] = None

    def __post_init__(self):
        if self.timestamp is None:
            # datetime.utcnow() is deprecated. Take an aware UTC "now" and
            # strip tzinfo so the serialized string keeps the same shape
            # (no offset suffix) that the original implementation emitted.
            now_utc = datetime.now(timezone.utc)
            self.timestamp = now_utc.replace(tzinfo=None).isoformat()

    def to_sse_data(self) -> str:
        """Format as SSE data line.

        Returns:
            ``"data: <json>"`` followed by the blank line that terminates an
            SSE event, where ``<json>`` is this event's fields as a JSON object.
        """
        data = asdict(self)
        return f"data: {json.dumps(data)}\n\n"


class SSEEventBroadcaster:
    """
    Manages SSE client connections and broadcasts events.

    Each client gets its own asyncio.Queue. When an event occurs, it's pushed
    to all relevant queues based on batch_id filtering.
    """

    def __init__(
        self,
        max_queue_size: int = DEFAULT_MAX_QUEUE_SIZE,
        keepalive_interval: float = DEFAULT_KEEPALIVE_INTERVAL,
    ):
        """
        Args:
            max_queue_size: Maximum number of undelivered events buffered per
                client; events beyond that are dropped with a warning. A value
                <= 0 makes the queues unbounded (asyncio.Queue semantics).
            keepalive_interval: Seconds to wait for an event before sending a
                keep-alive comment to the client.
        """
        # Dict of {client_id: (queue, batch_id_filter)}
        # batch_id_filter is None for clients that want all events
        self._clients: dict[str, tuple[asyncio.Queue, Optional[str]]] = {}
        self._client_counter = 0
        self._lock = asyncio.Lock()
        self._max_queue_size = max_queue_size
        self._keepalive_interval = keepalive_interval

    async def _generate_client_id(self) -> str:
        """Generate unique client ID from a counter plus the current epoch time."""
        async with self._lock:
            self._client_counter += 1
            return (
                f"client_{self._client_counter}_"
                f"{datetime.now(timezone.utc).timestamp()}"
            )

    async def subscribe(
        self,
        batch_id: Optional[str] = None,
    ) -> AsyncGenerator[str, None]:
        """
        Subscribe to SSE events.

        The client is registered before the first item is yielded and is
        removed from the registry when the generator finishes, is closed, or
        is cancelled.

        Args:
            batch_id: Optional filter - only receive events for this batch.
                If None, receives all events.

        Yields:
            SSE-formatted event strings (ready to send to client): first a
            ``retry`` hint, then events, with a keep-alive comment whenever no
            event arrives within the keep-alive interval.
        """
        client_id = await self._generate_client_id()
        # Bounded so a stalled client cannot accumulate events without limit;
        # this is also what makes the QueueFull handling in
        # broadcast_status_change reachable.
        queue: asyncio.Queue = asyncio.Queue(maxsize=self._max_queue_size)

        # Register client
        async with self._lock:
            self._clients[client_id] = (queue, batch_id)

        logger.info(
            "SSE client %s connected (batch_id filter: %s). Total clients: %s",
            client_id,
            batch_id,
            len(self._clients),
        )

        try:
            # Send initial retry hint for reconnection
            yield "retry: 3000\n\n"

            # Keep connection alive and yield events
            while True:
                try:
                    # Wait for events with timeout for keep-alive
                    event = await asyncio.wait_for(
                        queue.get(), timeout=self._keepalive_interval
                    )
                    yield event
                except asyncio.TimeoutError:
                    # Send keep-alive comment to prevent connection timeout
                    yield ": keep-alive\n\n"

        except asyncio.CancelledError:
            logger.info("SSE client %s subscription cancelled", client_id)
            raise
        finally:
            # Cleanup: remove client from registry
            async with self._lock:
                self._clients.pop(client_id, None)
            logger.info(
                "SSE client %s disconnected. Remaining clients: %s",
                client_id,
                len(self._clients),
            )

    async def broadcast_status_change(
        self,
        receipt_id: int,
        status: str,
        processing_status: Optional[str] = None,
        batch_id: Optional[str] = None,
    ) -> int:
        """
        Broadcast a status change event to all relevant clients.

        Args:
            receipt_id: The receipt ID that changed.
            status: New workflow status (DRAFT, PENDING_REVIEW, etc.).
            processing_status: New processing status (pending, processing,
                completed, failed).
            batch_id: The batch ID this receipt belongs to (for filtering).

        Returns:
            Number of clients notified. Clients whose queue was full are not
            counted; the event is dropped for them.
        """
        event = StatusChangeEvent(
            receipt_id=receipt_id,
            status=status,
            processing_status=processing_status,
            batch_id=batch_id,
        )
        sse_data = event.to_sse_data()

        notified = 0
        async with self._lock:
            for client_id, (queue, client_batch_filter) in self._clients.items():
                # Send event if:
                # 1. Client has no filter (wants all events), OR
                # 2. Client's filter matches the event's batch_id
                if client_batch_filter is not None and client_batch_filter != batch_id:
                    continue
                try:
                    queue.put_nowait(sse_data)
                except asyncio.QueueFull:
                    logger.warning(
                        "SSE queue full for client %s, dropping event", client_id
                    )
                else:
                    notified += 1

        if notified > 0:
            logger.debug(
                "SSE broadcast: receipt_id=%s, status=%s, "
                "processing_status=%s, notified=%s clients",
                receipt_id,
                status,
                processing_status,
                notified,
            )

        return notified

    @property
    def client_count(self) -> int:
        """Get current number of connected clients."""
        return len(self._clients)


# Singleton instance for the application
sse_broadcaster = SSEEventBroadcaster()


# Convenience functions for external use
async def subscribe(batch_id: Optional[str] = None) -> AsyncGenerator[str, None]:
    """Subscribe to SSE status change events on the shared broadcaster.

    Args:
        batch_id: Optional filter - only receive events for this batch.
            If None, receives all events.

    Yields:
        SSE-formatted strings produced by the shared broadcaster.
    """
    stream = sse_broadcaster.subscribe(batch_id)
    try:
        async for event in stream:
            yield event
    finally:
        # Closing this wrapper does not close the inner generator by itself;
        # close it explicitly so the client is deregistered immediately rather
        # than whenever the event loop finalizes the abandoned generator.
        await stream.aclose()


async def broadcast_status_change(
    receipt_id: int,
    status: str,
    processing_status: Optional[str] = None,
    batch_id: Optional[str] = None,
) -> int:
    """Broadcast a status change event via the shared broadcaster.

    Returns:
        Number of clients notified.
    """
    return await sse_broadcaster.broadcast_status_change(
        receipt_id=receipt_id,
        status=status,
        processing_status=processing_status,
        batch_id=batch_id,
    )