Godwin Vincent cd0a29d2ae
Device management agent - AgentCore runtime, observability, frontend added (#241)
* updated README.md file with bearer token generation

* updated README.md file with bearer token generation; removed client ID and secret credentials

* removed hardcoded domain

* added agent runtime, frontend, observability and AgentCore Identity

* updated README.md file to reflect frontend testing
2025-08-13 09:31:29 -07:00

632 lines
27 KiB
Python

"""
FastAPI application for Device Management using Bedrock AgentCore
"""
import asyncio
import functools
import json
import logging
import os
import re
import secrets
import uuid
from typing import List, Dict, Any, Optional

import boto3
from dotenv import load_dotenv
from fastapi import FastAPI, Request, WebSocket, WebSocketDisconnect, HTTPException, Depends, Form
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse, RedirectResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from pydantic import BaseModel
from starlette.middleware.sessions import SessionMiddleware

# Import authentication module
from auth import get_login_url, exchange_code_for_tokens, validate_token, get_current_user, login_required, get_logout_url
# Load environment variables (including a local .env file, if present) before
# any configuration is read below.
load_dotenv()

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize FastAPI app
app = FastAPI(title="Device Management Chat Application")

# Session cookies are signed with this key. Taking it from SESSION_SECRET_KEY
# lets sessions survive restarts and be shared by several worker processes.
# The per-process random key is only a fallback: with it, every restart logs
# all users out and separate workers reject each other's cookies.
_session_secret = os.getenv("SESSION_SECRET_KEY")
if not _session_secret:
    logger.warning(
        "SESSION_SECRET_KEY is not set; using a random per-process session key"
    )
    _session_secret = secrets.token_urlsafe(32)
app.add_middleware(
    SessionMiddleware,
    secret_key=_session_secret,
    max_age=3600,  # 1 hour session
)

# CORS. Allowed origins come from CORS_ALLOW_ORIGINS (comma-separated,
# default "*"). Credentials are enabled only for an explicit origin list:
# combined with "*", Starlette echoes the caller's Origin back on cookie-bearing
# requests, which would let any website make session-authenticated requests
# (e.g. read /profile) on a signed-in user's behalf. The bundled chat UI is
# rendered by this app itself (same origin), so it should not rely on CORS.
_cors_origins = [
    origin.strip()
    for origin in os.getenv("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
]
app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    allow_credentials="*" not in _cors_origins,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Jinja2 page templates and the static assets (JS/CSS) served under /static.
templates = Jinja2Templates(directory="templates")
app.mount("/static", StaticFiles(directory="static"), name="static")

# Runtime configuration taken from the environment.
AWS_REGION = os.getenv("AWS_REGION", "us-west-2")
AGENT_ARN = os.getenv("AGENT_ARN")
if AGENT_ARN is None or AGENT_ARN == "":
    # Fail fast at import time: every chat message is sent to this runtime ARN.
    logger.error("AGENT_ARN environment variable is not set")
    raise ValueError("AGENT_ARN environment variable is required")
# Models
class Message(BaseModel):
    """One chat message: a ``role`` label plus its text ``content``."""
    role: str  # NOTE(review): free-form here; presumably "user"/"assistant" — confirm
    content: str


class ChatRequest(BaseModel):
    """Chat request body holding an ordered list of ``Message`` objects.

    NOTE(review): neither model is referenced by the routes in this module.
    """
    messages: List[Message]
# Connection manager for WebSockets
class ConnectionManager:
    """Track open WebSocket connections and their AgentCore runtime session IDs.

    Both maps are keyed by the ``client_id`` taken from the WebSocket URL. A
    session ID of ``None`` means the next message starts a new agent
    conversation.
    """

    def __init__(self):
        self.active_connections: Dict[str, WebSocket] = {}
        # client_id -> runtime session ID (None until the agent assigns one,
        # or after the stored ID has been reset).
        self.session_ids: Dict[str, Optional[str]] = {}

    async def connect(self, websocket: WebSocket, client_id: str):
        """Accept ``websocket`` and register it under ``client_id`` with no session yet."""
        await websocket.accept()
        self.active_connections[client_id] = websocket
        self.session_ids[client_id] = None

    def disconnect(self, client_id: str):
        """Forget the connection and session of ``client_id``; unknown IDs are ignored."""
        self.active_connections.pop(client_id, None)
        self.session_ids.pop(client_id, None)

    async def send_message(self, message: str, client_id: str):
        """Send ``message`` as a text frame to ``client_id``; dropped if not connected."""
        websocket = self.active_connections.get(client_id)
        if websocket is not None:
            await websocket.send_text(message)

    def get_session_id(self, client_id: str) -> Optional[str]:
        """Return the stored runtime session ID, or None if there is none."""
        return self.session_ids.get(client_id)

    def set_session_id(self, client_id: str, session_id: Optional[str]):
        """Store ``session_id`` for ``client_id``; pass None to start a new conversation."""
        self.session_ids[client_id] = session_id


manager = ConnectionManager()
def _decode_stream_line(line):
    """Decode one line of agent output into a JSON value, or return None.

    Recognises bare JSON objects (``{...}``) and Server-Sent-Events ``data:``
    lines; as in the SSE format, the space after ``data:`` is optional.
    Blank lines, other text and undecodable payloads yield None.
    """
    line = line.strip()
    if line.startswith('{') and line.endswith('}'):
        payload = line
    elif line.startswith('data:'):
        payload = line[len('data:'):].strip()
    else:
        return None
    if not payload:
        return None
    try:
        return json.loads(payload)
    except json.JSONDecodeError as e:
        logger.debug("Failed to parse JSON line: %s... Error: %s", line[:100], e)
        return None


def _find_final_response(records, min_message_length):
    """Return the newest complete answer among ``records``, or None.

    Scans from the end and stops at the first record that is either
    ``{"type": "complete", "final_response": ...}`` (its value is returned
    as-is) or ``{"message": {"content": [{"text": ...}, ...]}}`` whose
    space-joined string texts are longer than ``min_message_length``.
    """
    for record in reversed(records):
        if not isinstance(record, dict):
            continue
        if record.get('type') == 'complete' and 'final_response' in record:
            logger.debug("Found complete type with final_response")
            return record['final_response']
        message = record.get('message')
        if not isinstance(message, dict):
            continue
        content_list = message.get('content')
        if not isinstance(content_list, list):
            continue
        text_parts = [
            item['text']
            for item in content_list
            if isinstance(item, dict) and isinstance(item.get('text'), str)
        ]
        if not text_parts:
            continue
        candidate_response = ' '.join(text_parts)
        # Short messages are likely intermediate chatter, not the final answer.
        if len(candidate_response) > min_message_length:
            logger.debug("Found substantial message content")
            return candidate_response
    return None


def _accumulate_stream_text(records):
    """Concatenate streamed text chunks from ``records`` in arrival order.

    Chunks come from ``{"event": {"contentBlockDelta": {"delta": {"text": ...}}}}``
    records or, for records without an ``event`` key, from a string ``data`` field.
    """
    chunks = []
    for record in records:
        if not isinstance(record, dict):
            continue
        if 'event' in record:
            event = record['event']
            if not isinstance(event, dict):
                continue
            delta = event.get('contentBlockDelta')
            if not isinstance(delta, dict):
                continue
            delta_data = delta.get('delta')
            if isinstance(delta_data, dict) and isinstance(delta_data.get('text'), str):
                chunks.append(delta_data['text'])
        elif isinstance(record.get('data'), str):
            chunks.append(record['data'])
    return ''.join(chunks)


def parse_streaming_response(content, min_message_length=200):
    """Parse streaming response content to extract the final response text.

    ``content`` is the decoded runtime response body: newline-separated
    records, each a bare JSON object or an SSE ``data:`` line. The reply is
    chosen in this order:

    1. Scanning backwards, the first ``{"type": "complete", "final_response": ...}``
       record, or ``message`` record whose joined text is longer than
       ``min_message_length`` characters (a falsy ``final_response`` falls
       through to step 2).
    2. Otherwise, the concatenation of all streamed text chunks.
    3. Otherwise, a diagnostic string quoting the first 500 characters.

    Args:
        content: Response body text.
        min_message_length: Minimum length for a ``message`` record to be
            accepted as the final answer (default 200).

    Returns:
        The extracted text. Unexpected errors are logged and returned as an
        ``"Error parsing response: ..."`` string instead of being raised.
    """
    try:
        logger.debug("Parsing streaming content: %d characters", len(content))
        # Decode every line once; both passes below work on these records.
        records = [_decode_stream_line(line) for line in content.strip().split('\n')]

        final_response = _find_final_response(records, min_message_length)
        if final_response:
            logger.info("Extracted final response: %d characters", len(final_response))
            return final_response

        accumulated_text = _accumulate_stream_text(records)
        if accumulated_text:
            logger.info("Using accumulated text: %d characters", len(accumulated_text))
            return accumulated_text

        logger.warning("No response text found in streaming data")
        return f"No parseable response found. Raw content sample: {content[:500]}..."
    except Exception as e:
        logger.error(f"Error parsing streaming response: {str(e)}")
        return f"Error parsing response: {str(e)}"
# Connection status -> emoji shown next to each device in a device list.
_DEVICE_STATUS_EMOJI = {
    'Connected': '🟢',
    'Disconnected': '🔴',
    'Updating': '🟡',
    'Dormant': '🟠',
    'Maintenance': '🔧',
}

# Numbered-list marker at the start of a line ("1. ", "12) ", ...). Limited to
# three digits so that lines such as "2024. Revenue grew" are left untouched.
_NUMBERED_ITEM_MARKER = re.compile(r'\d{1,3}[.)] ')


def _format_device_list(devices):
    """Render a list of device dicts as a numbered Markdown list with status emoji."""
    parts = ["📱 **Device List:**\n\n"]
    for index, item in enumerate(devices, 1):
        name = item.get('name', 'Unknown Device')
        device_id = item.get('device_id', item.get('id', 'Unknown ID'))
        status = item.get('connection_status', item.get('status', 'Unknown'))
        status_emoji = _DEVICE_STATUS_EMOJI.get(status, '')
        parts.append(f"**{index}. {name}** {status_emoji}\n")
        parts.append(f" • ID: `{device_id}`\n")
        if 'model' in item:
            parts.append(f" • Model: {item['model']}\n")
        if 'ip_address' in item:
            parts.append(f" • IP: {item['ip_address']}\n")
        if 'connection_status' in item:
            parts.append(f" • Status: {item['connection_status']}\n")
        if 'firmware_version' in item:
            parts.append(f" • Firmware: {item['firmware_version']}\n")
        if 'last_connected' in item:
            timestamp = item['last_connected']
            if isinstance(timestamp, str) and 'T' in timestamp:
                # "2024-01-02T03:04:05.678Z" -> "2024-01-02 at 03:04:05"
                date_part, time_part = timestamp.split('T')[:2]
                time_part = time_part.split('.')[0]
                parts.append(f" • Last Connected: {date_part} at {time_part}\n")
            else:
                parts.append(f" • Last Connected: {timestamp}\n")
        parts.append("\n")
    return "".join(parts).strip()


def _format_json_text(text):
    """Format ``text`` if it is a JSON object/array; return None if it is not JSON."""
    looks_like_json = (text.startswith('{') and text.endswith('}')) or \
        (text.startswith('[') and text.endswith(']'))
    if not looks_like_json:
        return None
    try:
        parsed = json.loads(text)
    except json.JSONDecodeError:
        # Not valid JSON, fall back to regular formatting
        return None
    is_device_list = isinstance(parsed, list) and parsed and all(
        isinstance(item, dict) and ('device_id' in item or 'name' in item)
        for item in parsed
    )
    if is_device_list:
        return _format_device_list(parsed)
    # For other JSON data, pretty print with indentation
    return f"```json\n{json.dumps(parsed, indent=2)}\n```"


def _format_plain_text(text):
    """Strip each line, drop list markers and collapse runs of blank lines."""
    formatted_lines = []
    for raw_line in text.split('\n'):
        line = raw_line.strip()
        marker = _NUMBERED_ITEM_MARKER.match(line)
        if marker:
            # Drop exactly the leading "N. " / "N) " marker, nothing after it.
            line = line[marker.end():]
        elif line.startswith('- '):
            line = line[2:]
        # "key: value" lines are deliberately left as-is: bolding every line
        # that contains ':' would also mangle URLs and times.
        formatted_lines.append(line)
    result = '\n'.join(formatted_lines)
    # Collapse three or more consecutive newlines into a single blank line.
    result = re.sub(r'\n{3,}', '\n\n', result)
    return result.strip()


def format_response_text(text):
    """Format response text for better readability in the chat UI.

    - A JSON array whose items are all dicts with a ``device_id`` or ``name``
      key is rendered as a numbered device list; any other JSON object/array
      is pretty-printed inside a ```json fenced block.
    - Anything else has literal backslash escapes (backslash-n, backslash
      followed by a quote) unescaped. Each line is stripped, leading
      "1. " / "12) " / "- " list markers are removed, and runs of blank lines
      are collapsed to one.

    Args:
        text: Raw response text; falsy values yield "".

    Returns:
        The formatted text. If formatting fails unexpectedly the error is
        logged and the (possibly partially cleaned) text is returned.
    """
    if not text:
        return ""
    try:
        text = text.strip()
        formatted_json = _format_json_text(text)
        if formatted_json is not None:
            return formatted_json
        # Replace literal escape sequences with the characters they denote.
        text = text.replace('\\n', '\n').replace('\\"', '"').replace("\\'", "'")
        return _format_plain_text(text)
    except Exception as e:
        logger.error(f"Error formatting response: {str(e)}")
        return text  # Return original text if formatting fails
def create_agentcore_client(auth_token=None):
    """Build a boto3 ``bedrock-agentcore`` client bound to ``AWS_REGION``.

    Only the client is returned; the intermediate boto3 session is discarded.
    ``auth_token`` is accepted for call compatibility but is not used, so the
    client authenticates with whatever credentials boto3 resolves by default.
    """
    session = boto3.Session(region_name=AWS_REGION)
    return session.client('bedrock-agentcore', region_name=AWS_REGION)
# WebSocket chat helpers

async def _invoke_agent_with_retry(agentcore_client, user_message, session_id, client_id,
                                   max_retries=3, initial_delay=1):
    """Invoke the agent runtime, retrying throttled calls with exponential backoff.

    The boto3 call is synchronous, so it runs in the default thread-pool
    executor; this keeps the event loop (and every other WebSocket) responsive.

    Args:
        agentcore_client: boto3 ``bedrock-agentcore`` client.
        user_message: Prompt text, sent as ``{"prompt": user_message}``.
        session_id: Existing runtime session ID, or None to start a new one.
        client_id: WebSocket client notified while waiting between retries.
        max_retries: Total number of attempts.
        initial_delay: Seconds before the first retry; doubled after each one.

    Returns:
        The raw ``invoke_agent_runtime`` response.

    Raises:
        botocore.exceptions.ClientError: on non-throttling errors, or when the
            last attempt is still throttled.
    """
    from botocore.exceptions import ClientError

    request = {
        "agentRuntimeArn": AGENT_ARN,
        "qualifier": "DEFAULT",
        "payload": json.dumps({"prompt": user_message}),
    }
    if session_id is not None:
        request["runtimeSessionId"] = session_id

    loop = asyncio.get_running_loop()
    retry_delay = initial_delay
    attempt = 0
    while True:
        attempt += 1
        if session_id is None:
            logger.info("Starting new conversation with streaming")
        else:
            logger.info("Continuing conversation with session ID: %s", session_id)
        try:
            return await loop.run_in_executor(
                None, functools.partial(agentcore_client.invoke_agent_runtime, **request)
            )
        except ClientError as e:
            error_code = e.response.get('Error', {}).get('Code', '')
            # Case-insensitive so both "ThrottlingException" and
            # "throttlingException" spellings of the code are retried.
            if error_code.lower() != 'throttlingexception' or attempt >= max_retries:
                raise
            logger.warning(f"Throttling exception encountered. Retrying in {retry_delay} seconds...")
            await manager.send_message(json.dumps({"status": f"Rate limited. Retrying in {retry_delay} seconds..."}), client_id)
            # asyncio.sleep, not time.sleep: a blocking sleep would freeze every connection.
            await asyncio.sleep(retry_delay)
            retry_delay *= 2  # Exponential backoff


async def _send_agent_reply(boto3_response, new_session_id, client_id):
    """Extract the reply from an AgentCore response and send it (or an error) to the client."""
    if not (isinstance(boto3_response, dict) and "response" in boto3_response):
        # Fallback to non-streaming response handling
        logger.warning('No streaming response found, falling back to non-streaming')
        await manager.send_message(json.dumps({
            'response': format_response_text(str(boto3_response)),
            'sessionId': new_session_id
        }), client_id)
        return

    try:
        response_stream = boto3_response["response"]
        logger.info("Processing streaming response, type: %s", type(response_stream))
        if hasattr(response_stream, 'read'):
            # StreamingBody.read() blocks on network I/O; keep it off the event loop.
            content = await asyncio.get_running_loop().run_in_executor(None, response_stream.read)
            if isinstance(content, bytes):
                content = content.decode('utf-8')
            logger.debug("Raw streaming content received: %d characters", len(content))
            empty_error = "No valid response content found in streaming data"
        else:
            # Fallback: convert to string
            content = str(response_stream)
            empty_error = "No valid response content found"
        final_response_text = parse_streaming_response(content)
        if final_response_text:
            reply = {
                "response": format_response_text(final_response_text),
                "sessionId": new_session_id,
                "complete": True
            }
        else:
            reply = {"error": empty_error}
        await manager.send_message(json.dumps(reply), client_id)
    except Exception as e:
        logger.error(f'Error processing streaming response: {str(e)}')
        await manager.send_message(json.dumps({
            'error': f'Error processing streaming response: {str(e)}'
        }), client_id)


def _describe_agent_error(error_message):
    """Translate a raw agent error message into user-facing text.

    Returns ``(message, reset_session)``; ``reset_session`` is True when the
    stored runtime session ID was rejected and should be discarded.
    """
    if 'throttlingexception' in error_message.lower():
        return ('Too many requests. The service is temporarily throttling requests. '
                'Please try again in a few moments.'), False
    if 'AccessDeniedException' in error_message:
        return 'Access denied. Please check your AWS credentials and permissions.', False
    if 'ValidationException' in error_message and 'runtimeSessionId' in error_message:
        return 'Invalid session ID. Starting a new conversation.', True
    return error_message, False


# Routes
@app.websocket("/ws/{client_id}")
async def websocket_endpoint(websocket: WebSocket, client_id: str):
    """WebSocket endpoint for real-time chat with the AgentCore agent.

    Each text frame is one user prompt forwarded to the runtime ``AGENT_ARN``.
    The runtime session ID returned by the agent is remembered per
    ``client_id`` so follow-up messages continue the same conversation.
    Replies are JSON text frames: ``{"response", "sessionId", "complete"}`` on
    success, ``{"status"}`` while retrying after throttling, and ``{"error"}``
    on failure. The client is always unregistered when the handler exits.
    """
    await manager.connect(websocket, client_id)
    try:
        agentcore_client = create_agentcore_client()
        while True:
            user_message = (await websocket.receive_text()).strip()
            if not user_message:
                await manager.send_message(json.dumps({"error": "Empty message"}), client_id)
                continue
            try:
                session_id = manager.get_session_id(client_id)
                boto3_response = await _invoke_agent_with_retry(
                    agentcore_client, user_message, session_id, client_id
                )
                if isinstance(boto3_response, dict) and 'runtimeSessionId' in boto3_response:
                    new_session_id = boto3_response['runtimeSessionId']
                    logger.info("Received new session ID: %s", new_session_id)
                    manager.set_session_id(client_id, new_session_id)
                else:
                    logger.warning("No runtimeSessionId in response")
                    # Keep using the existing session ID if available
                    new_session_id = session_id
                await _send_agent_reply(boto3_response, new_session_id, client_id)
            except Exception as e:
                error_message = str(e)
                logger.error(f'Error processing request with agent: {error_message}')
                error_message, reset_session = _describe_agent_error(error_message)
                if reset_session:
                    manager.set_session_id(client_id, None)
                await manager.send_message(json.dumps({
                    'error': f'Error processing request with agent: {error_message}'
                }), client_id)
    except WebSocketDisconnect:
        # Normal client-initiated close; cleanup happens in ``finally``.
        pass
    except Exception as e:
        logger.error(f'WebSocket error: {str(e)}')
    finally:
        manager.disconnect(client_id)
# Authentication routes
@app.get("/", response_class=HTMLResponse)
async def root(request: Request):
    """Serve the chat page to signed-in users; send everyone else to the simple login form."""
    current_user = await get_current_user(request)
    if current_user:
        return templates.TemplateResponse("chat.html", {"request": request, "user": current_user})
    # Not signed in: the simple (non-Cognito) login form is the default entry point.
    return RedirectResponse(url="/simple-login")
@app.get("/login", response_class=HTMLResponse)
async def login_page(request: Request):
    """Render the Cognito login page, or redirect already signed-in users to ``/``."""
    if await get_current_user(request):
        return RedirectResponse(url="/")
    context = {"request": request, "login_url": get_login_url()}
    return templates.TemplateResponse("login.html", context)
@app.get("/auth/callback")
async def auth_callback(
    request: Request,
    code: Optional[str] = None,
    error: Optional[str] = None,
    error_description: Optional[str] = None
):
    """Handle the redirect back from Cognito after sign-in.

    Exchanges ``code`` via ``exchange_code_for_tokens`` and passes the ID
    token to ``validate_token``. It then stores the resulting claims and
    tokens in ``request.session["user"]`` and redirects to ``/``. Provider
    errors, a missing code, or any exception during the exchange re-render
    ``login.html`` with the error text and HTTP 400.
    """
    def render_login_error(message):
        # Login page with a freshly generated login URL and the error shown.
        context = {
            "request": request,
            "login_url": get_login_url(),
            "error": message
        }
        return templates.TemplateResponse("login.html", context, status_code=400)

    if error:
        suffix = f" - {error_description}" if error_description else ""
        error_msg = f"Authentication error: {error}" + suffix
        logger.error(error_msg)
        return render_login_error(error_msg)

    if not code:
        logger.error("No authorization code provided")
        return render_login_error("No authorization code provided")

    try:
        tokens = await exchange_code_for_tokens(code)
        id_token = tokens["id_token"]
        claims = await validate_token(id_token)
        session_user = {
            "sub": claims["sub"],
            "email": claims.get("email", ""),
            "name": claims.get("name", ""),
            "access_token": tokens["access_token"],
            "id_token": id_token
        }
        request.session["user"] = session_user
        return RedirectResponse(url="/")
    except Exception as e:
        logger.error(f"Authentication error: {str(e)}")
        return render_login_error(str(e))
@app.get("/logout")
async def logout(request: Request):
    """Clear the session and auth cookies, then redirect to the simple login page.

    NOTE(review): ``get_logout_url`` is imported but not called here, so any
    Cognito-side sign-in is presumably left active — confirm this is intended.
    """
    request.session.clear()
    response = RedirectResponse(url="/simple-login")
    for cookie_name in ("access_token", "simple_user"):
        response.delete_cookie(cookie_name)
    return response
@app.get("/profile")
async def profile(request: Request, user: dict = Depends(login_required)):
    """Return the user object provided by ``login_required`` as ``{"user": ...}``.

    NOTE(review): if this is the session record written by the login routes,
    it includes ``access_token`` and ``id_token`` — confirm exposing them is intended.
    """
    return dict(user=user)
@app.get("/simple-login", response_class=HTMLResponse)
async def simple_login_page(request: Request):
    """Show the demo (non-Cognito) login form unless the visitor is already signed in."""
    already_signed_in = await get_current_user(request)
    if not already_signed_in:
        return templates.TemplateResponse("simple_login.html", {"request": request})
    return RedirectResponse(url="/")
@app.post("/simple-login")
async def simple_login_submit(request: Request, username: str = Form(...), password: str = Form(...)):
    """Sign a user in through the demo login form.

    WARNING: demo only. Any username/password pair is accepted and ``password``
    is never checked. A real deployment must validate credentials against an
    actual identity store.
    """
    demo_user = {
        "sub": "simple-user-123",
        "email": username,
        "name": username,
        "access_token": "demo-token",
        "id_token": "demo-token",
    }
    request.session["user"] = demo_user
    # 303 See Other makes the browser follow up this POST with a GET to "/".
    return RedirectResponse(url="/", status_code=303)
if __name__ == "__main__":
    # Development entry point: auto-reloading server on all interfaces, port 8000.
    # "main:app" assumes this file is saved as main.py.
    import uvicorn

    uvicorn.run(
        "main:app",
        host="0.0.0.0",
        port=8000,
        reload=True,
    )