mirror of
https://github.com/awslabs/amazon-bedrock-agentcore-samples.git
synced 2025-09-08 20:50:46 +00:00
* updated README.md file with bearer token generation * updated README.md file with bearer token generation-removed client id and secret credentials * removed hardcoded domain * added agent runtime, frontend, observability and agentcore identity * update README.md file to reflect frontend testing
632 lines
27 KiB
Python
632 lines
27 KiB
Python
"""
FastAPI application for Device Management using Bedrock AgentCore
"""
|
|
import os
|
|
import json
|
|
import logging
|
|
import uuid
|
|
import secrets
|
|
from typing import List, Dict, Any, Optional
|
|
from fastapi import FastAPI, Request, WebSocket, WebSocketDisconnect, HTTPException, Depends, Form
|
|
from fastapi.responses import HTMLResponse, RedirectResponse
|
|
from fastapi.templating import Jinja2Templates
|
|
from fastapi.staticfiles import StaticFiles
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from starlette.middleware.sessions import SessionMiddleware
|
|
from pydantic import BaseModel
|
|
from dotenv import load_dotenv
|
|
import boto3
|
|
|
|
# Import authentication module
|
|
from auth import get_login_url, exchange_code_for_tokens, validate_token, get_current_user, login_required, get_logout_url
|
|
|
|
# Load environment variables from a local .env file (if present) before any
# os.getenv() lookups below.
load_dotenv()

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize FastAPI app
app = FastAPI(title="Device Management Chat Application")

# Session middleware.
# A key generated at import time changes on every process start: every user is
# logged out on restart/reload, and cookies signed by one worker are rejected
# by another. Prefer a stable key from SESSION_SECRET_KEY; fall back to a
# random per-process key so an unconfigured local run still works.
_session_secret = os.getenv("SESSION_SECRET_KEY") or secrets.token_urlsafe(32)
app.add_middleware(
    SessionMiddleware,
    secret_key=_session_secret,
    max_age=3600,  # 1 hour session
)

# CORS middleware.
# CORS_ALLOWED_ORIGINS is a comma-separated list of origins; when unset, any
# origin is allowed. Credentialed cross-origin requests are only enabled for
# an explicit origin list, because wildcard origins combined with credentials
# would let any website make authenticated requests on behalf of a user.
_cors_origins = [
    origin.strip()
    for origin in os.getenv("CORS_ALLOWED_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    allow_credentials="*" not in _cors_origins,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Set up templates and static files
templates = Jinja2Templates(directory="templates")
app.mount("/static", StaticFiles(directory="static"), name="static")

# Environment variables
AWS_REGION = os.getenv("AWS_REGION", "us-west-2")
AGENT_ARN = os.getenv("AGENT_ARN")

# Fail fast at import time: nothing in this application works without an
# agent runtime to invoke.
if not AGENT_ARN:
    logger.error("AGENT_ARN environment variable is not set")
    raise ValueError("AGENT_ARN environment variable is required")
|
|
|
|
# Models
|
|
class Message(BaseModel):
    """A single chat message: who said it (``role``) and the text (``content``)."""

    # Speaker of the message; the set of accepted values is not constrained here.
    role: str
    # Message text.
    content: str
|
|
|
|
class ChatRequest(BaseModel):
    """Request body carrying an ordered list of chat messages."""

    # Conversation messages, each validated as a ``Message``.
    messages: List[Message]
|
|
|
|
# Connection manager for WebSockets
|
|
class ConnectionManager:
    """Registry of open WebSocket connections and their agent session IDs.

    Both mappings are keyed by the ``client_id`` supplied by the caller.
    """

    def __init__(self):
        # client_id -> accepted WebSocket
        self.active_connections: Dict[str, WebSocket] = {}
        # client_id -> runtime session ID (None until one has been stored)
        self.session_ids: Dict[str, str] = {}  # Map client_id to runtime_session_id

    async def connect(self, websocket: WebSocket, client_id: str):
        """Accept ``websocket`` and register it under ``client_id`` with no session."""
        await websocket.accept()
        self.session_ids[client_id] = None  # Initialize with no session
        self.active_connections[client_id] = websocket

    def disconnect(self, client_id: str):
        """Forget the connection and session for ``client_id``; unknown IDs are ignored."""
        self.active_connections.pop(client_id, None)
        self.session_ids.pop(client_id, None)

    async def send_message(self, message: str, client_id: str):
        """Send ``message`` as text to ``client_id`` if it is still registered."""
        socket = self.active_connections.get(client_id)
        if socket is None:
            return
        await socket.send_text(message)

    def get_session_id(self, client_id: str) -> Optional[str]:
        """Return the stored session ID for ``client_id``, or None if absent."""
        try:
            return self.session_ids[client_id]
        except KeyError:
            return None

    def set_session_id(self, client_id: str, session_id: str):
        """Store ``session_id`` for ``client_id``, replacing any previous value."""
        self.session_ids[client_id] = session_id
|
|
|
|
manager = ConnectionManager()
|
|
|
|
def _decode_stream_line(line):
    """Decode one stream line (bare JSON object or SSE ``data: `` line); None if not decodable."""
    line = line.strip()
    if not line:
        return None

    if line.startswith('{') and line.endswith('}'):
        # Bare JSON object (AgentCore responses)
        payload = line
    elif line.startswith('data: '):
        # SSE format with 'data: ' prefix
        payload = line[6:].strip()
        if not payload:
            return None
    else:
        return None

    try:
        return json.loads(payload)
    except (ValueError, RecursionError) as e:
        logger.debug(f"Failed to parse JSON line: {line[:100]}... Error: {e}")
        return None


def _message_text(event):
    """Join the string ``text`` parts of ``event['message']['content']`` with spaces ('' if none)."""
    message = event.get('message')
    if not isinstance(message, dict):
        return ""
    content_list = message.get('content')
    if not isinstance(content_list, list):
        return ""
    return ' '.join(
        item['text']
        for item in content_list
        if isinstance(item, dict) and isinstance(item.get('text'), str)
    )


def _delta_text(event):
    """Return the text fragment carried by a streaming chunk event ('' if none)."""
    if 'event' in event:
        inner = event['event']
        if isinstance(inner, dict):
            # contentBlockDelta events: event.contentBlockDelta.delta.text
            block_delta = inner.get('contentBlockDelta')
            if isinstance(block_delta, dict):
                delta = block_delta.get('delta')
                if isinstance(delta, dict) and isinstance(delta.get('text'), str):
                    return delta['text']
        return ""

    # Plain chunk data: {"data": "<text>"}
    data = event.get('data')
    return data if isinstance(data, str) else ""


def parse_streaming_response(content):
    """Extract the final response text from raw streaming content.

    ``content`` is split into lines; each line is decoded as a bare JSON object
    or an SSE ``data: `` line, and undecodable lines are skipped. The result is
    chosen in this order:

    1. Scanning from the end, the first event that is either
       ``{"type": "complete", "final_response": ...}`` or a ``message`` event
       whose joined text is longer than 200 characters.
    2. The concatenation of all streamed text chunks (``contentBlockDelta``
       text and string ``data`` values), in stream order.
    3. The last non-empty ``message`` text, even if short. Without this step a
       brief answer delivered only as a ``message`` event was reported as
       unparseable.
    4. A diagnostic string containing the first 500 characters of ``content``.

    Returns:
        The extracted text, or a diagnostic/error string; this function does
        not raise for ordinary ``Exception`` subclasses.
    """
    try:
        logger.debug(f"Parsing streaming content: {len(content)} characters")

        decoded = (_decode_stream_line(line) for line in content.strip().split('\n'))
        events = [event for event in decoded if isinstance(event, dict)]

        final_response = ""
        short_message = ""

        # Look for the final complete message at the end
        for event in reversed(events):
            if event.get('type') == 'complete' and 'final_response' in event:
                final_response = event['final_response']
                logger.debug("Found complete type with final_response")
                break

            if 'message' in event:
                candidate_response = _message_text(event)
                # A long message is treated as the final answer right away.
                if len(candidate_response) > 200:
                    final_response = candidate_response
                    logger.debug("Found substantial message content")
                    break
                # Remember the latest short message as a last-resort answer.
                if candidate_response and not short_message:
                    short_message = candidate_response

        if final_response:
            logger.info(f"Extracted final response: {len(final_response)} characters")
            return final_response

        # If no final response found, accumulate text chunks
        accumulated_text = ''.join(_delta_text(event) for event in events)
        if accumulated_text:
            logger.info(f"Using accumulated text: {len(accumulated_text)} characters")
            return accumulated_text

        if short_message:
            logger.info(f"Using short message content: {len(short_message)} characters")
            return short_message

        logger.warning("No response text found in streaming data")
        return f"No parseable response found. Raw content sample: {content[:500]}..."

    except Exception as e:
        logger.error(f"Error parsing streaming response: {str(e)}")
        return f"Error parsing response: {str(e)}"
|
|
|
|
# Status label -> indicator shown next to the device name.
_STATUS_EMOJI = {
    'Connected': '🟢',
    'Disconnected': '🔴',
    'Updating': '🟡',
    'Dormant': '🟠',
    'Maintenance': '🔧',
}

# Optional device fields rendered as "label: value" lines, in display order.
_DEVICE_DETAIL_FIELDS = (
    ('model', 'Model'),
    ('ip_address', 'IP'),
    ('connection_status', 'Status'),
    ('firmware_version', 'Firmware'),
)


def _format_device_list(devices):
    """Render a list of device dicts as a numbered Markdown list."""
    out = ["📱 **Device List:**\n\n"]
    for i, item in enumerate(devices, 1):
        name = item.get('name', 'Unknown Device')
        device_id = item.get('device_id', item.get('id', 'Unknown ID'))
        status = item.get('connection_status', item.get('status', 'Unknown'))
        # Non-string statuses (possibly unhashable) get the neutral marker.
        status_emoji = _STATUS_EMOJI.get(status, '⚪') if isinstance(status, str) else '⚪'

        out.append(f"**{i}. {name}** {status_emoji}\n")
        out.append(f" • ID: `{device_id}`\n")

        for key, label in _DEVICE_DETAIL_FIELDS:
            if key in item:
                out.append(f" • {label}: {item[key]}\n")

        if 'last_connected' in item:
            timestamp = item['last_connected']
            # Only ISO-like strings are split; null or numeric values are
            # printed as-is instead of raising.
            if isinstance(timestamp, str) and 'T' in timestamp:
                date_part, time_part = timestamp.split('T')[:2]
                time_part = time_part.split('.')[0]
                out.append(f" • Last Connected: {date_part} at {time_part}\n")
            else:
                out.append(f" • Last Connected: {timestamp}\n")

        out.append("\n")

    return ''.join(out).strip()


def _format_plain_line(line):
    """Normalize one stripped, non-empty text line (list markers and ``key: value`` pairs)."""
    # Convert single-digit numbered items ("1. x" / "1) x") to bullets. The
    # marker is exactly three characters, so slice it off rather than
    # splitting on the first ". " found anywhere in the line.
    if len(line) > 2 and line[0].isdigit() and line[1:3] in ('. ', ') '):
        return '• ' + line[3:]

    # Ensure consistent bullet formatting
    if line.startswith('- '):
        return '• ' + line[2:]

    # Format key-value pairs nicely
    if ':' in line and not line.startswith('•'):
        key, _, value = line.partition(':')
        key, value = key.strip(), value.strip()
        if key and value:
            return f"**{key}:** {value}"

    return line


def format_response_text(text):
    """Format response text for better readability in the UI.

    * Falsy input returns ``""``.
    * Text that parses as a JSON list of dicts, each having ``device_id`` or
      ``name``, is rendered as a numbered device list.
    * Any other valid JSON object/array is returned pretty-printed inside a
      fenced ``json`` code block.
    * Otherwise the text is treated as plain text: escaped ``\\n``, ``\\"`` and
      ``\\'`` sequences are unescaped, list markers become ``•`` bullets,
      ``key: value`` lines get a bold key, and runs of blank lines collapse.

    If formatting raises, the error is logged and the (stripped) input text is
    returned unchanged.
    """
    if not text:
        return ""

    try:
        # Clean up the text first
        text = text.strip()

        # Try to parse as JSON if it looks like JSON
        looks_like_json = (text.startswith('{') and text.endswith('}')) or \
            (text.startswith('[') and text.endswith(']'))
        if looks_like_json:
            try:
                parsed = json.loads(text)
            except json.JSONDecodeError:
                # Not valid JSON, continue with regular formatting
                pass
            else:
                # Every element must be a dict; a mixed list is pretty-printed
                # as generic JSON rather than failing mid-render.
                is_device_list = (
                    isinstance(parsed, list)
                    and len(parsed) > 0
                    and all(
                        isinstance(item, dict) and ('device_id' in item or 'name' in item)
                        for item in parsed
                    )
                )
                if is_device_list:
                    return _format_device_list(parsed)

                # For other JSON data, pretty print with indentation
                return f"```json\n{json.dumps(parsed, indent=2)}\n```"

        # Replace escaped characters
        text = text.replace('\\n', '\n').replace('\\"', '"').replace("\\'", "'")

        formatted_lines = []
        for raw_line in text.split('\n'):
            line = raw_line.strip()
            formatted_lines.append(_format_plain_line(line) if line else '')

        result = '\n'.join(formatted_lines)

        # Clean up excessive whitespace
        while '\n\n\n' in result:
            result = result.replace('\n\n\n', '\n\n')

        return result.strip()

    except Exception as e:
        logger.error(f"Error formatting response: {str(e)}")
        return text  # Return original text if formatting fails
|
|
|
|
def create_agentcore_client(auth_token=None):
    """Build and return a boto3 ``bedrock-agentcore`` client for ``AWS_REGION``.

    A fresh boto3 session is created on every call.

    Args:
        auth_token: Accepted for interface compatibility; not used here.

    Returns:
        The boto3 client object.
    """
    session = boto3.Session(region_name=AWS_REGION)
    return session.client('bedrock-agentcore', region_name=AWS_REGION)
|
|
|
|
# Routes
|
|
@app.websocket("/ws/{client_id}")
async def websocket_endpoint(websocket: WebSocket, client_id: str):
    """WebSocket endpoint for real-time chat with streaming support.

    Each received text frame is treated as a user prompt and sent to the agent
    runtime identified by ``AGENT_ARN``. The runtime session ID returned by
    the service is stored per ``client_id`` and reused on later prompts.

    Frames sent back to the client are JSON objects with one of:
      * ``response`` (+ ``sessionId``, and ``complete`` for streamed replies)
      * ``status`` (retry notice while throttled)
      * ``error``

    Throttled invocations are retried up to three times with exponential
    backoff (1s, 2s). The connection is unregistered on disconnect or on any
    unexpected error.
    """
    # Function-scope imports, as in the rest of this handler's dependencies.
    import asyncio
    from botocore.exceptions import ClientError

    # NOTE(review): no authentication check is performed before accepting the
    # socket; any client that can reach this route can invoke the agent.
    await manager.connect(websocket, client_id)

    try:
        # Create AgentCore client
        agentcore_client = create_agentcore_client()

        while True:
            data = await websocket.receive_text()
            user_message = data.strip()

            if not user_message:
                await manager.send_message(json.dumps({"error": "Empty message"}), client_id)
                continue

            try:
                # Get current session ID for this client
                session_id = manager.get_session_id(client_id)

                invoke_kwargs = {
                    "agentRuntimeArn": AGENT_ARN,
                    "qualifier": "DEFAULT",
                    "payload": json.dumps({"prompt": user_message}),
                }
                if session_id is None:
                    # First message in conversation
                    logger.info("Starting new conversation with streaming")
                else:
                    # Continuing conversation with existing session ID
                    logger.info(f"Continuing conversation with session ID: {session_id}")
                    invoke_kwargs["runtimeSessionId"] = session_id

                # Invoke the agent with retry logic
                max_retries = 3
                retry_delay = 1  # Start with 1 second delay

                for attempt in range(max_retries):
                    try:
                        # NOTE(review): this boto3 call (and the stream read
                        # below) is synchronous and holds the event loop while
                        # it runs.
                        boto3_response = agentcore_client.invoke_agent_runtime(**invoke_kwargs)
                        # If successful, break out of retry loop
                        break
                    except ClientError as e:
                        # Compare case-insensitively: the error code casing
                        # differs between "ThrottlingException" and
                        # "throttlingException" depending on the source.
                        error_code = str(e.response.get('Error', {}).get('Code', ''))
                        is_throttled = error_code.lower() == 'throttlingexception'
                        if is_throttled and attempt < max_retries - 1:
                            logger.warning(f"Throttling exception encountered. Retrying in {retry_delay} seconds...")
                            await manager.send_message(json.dumps({"status": f"Rate limited. Retrying in {retry_delay} seconds..."}), client_id)
                            # Yield to the event loop while waiting; a blocking
                            # sleep would stall every other connection.
                            await asyncio.sleep(retry_delay)
                            # Exponential backoff
                            retry_delay *= 2
                        else:
                            # Re-raise the exception if we've exhausted retries or it's not a throttling exception
                            raise

                # Update session ID
                if isinstance(boto3_response, dict) and 'runtimeSessionId' in boto3_response:
                    new_session_id = boto3_response['runtimeSessionId']
                    logger.info(f"Received new session ID: {new_session_id}")
                    manager.set_session_id(client_id, new_session_id)
                else:
                    logger.warning("No runtimeSessionId in response")
                    # Keep using the existing session ID if available
                    new_session_id = session_id

                # Handle streaming response from AgentCore
                if isinstance(boto3_response, dict) and "response" in boto3_response:
                    try:
                        response_stream = boto3_response["response"]
                        logger.info(f"Processing streaming response, type: {type(response_stream)}")

                        if hasattr(response_stream, 'read'):
                            # Handle StreamingBody properly
                            content = response_stream.read()
                            if isinstance(content, bytes):
                                content = content.decode('utf-8')
                            logger.debug(f"Raw streaming content received: {len(content)} characters")
                            empty_error = "No valid response content found in streaming data"
                        else:
                            # Fallback: convert to string
                            content = str(response_stream)
                            empty_error = "No valid response content found"

                        # Parse the streaming content to extract the final response
                        final_response_text = parse_streaming_response(content)

                        if final_response_text:
                            # Send final complete message
                            await manager.send_message(json.dumps({
                                "response": format_response_text(final_response_text),
                                "sessionId": new_session_id,
                                "complete": True
                            }), client_id)
                        else:
                            await manager.send_message(json.dumps({
                                "error": empty_error
                            }), client_id)

                    except Exception as e:
                        logger.error(f'Error processing streaming response: {str(e)}')
                        await manager.send_message(json.dumps({
                            'error': f'Error processing streaming response: {str(e)}'
                        }), client_id)
                else:
                    # Fallback to non-streaming response handling
                    logger.warning('No streaming response found, falling back to non-streaming')
                    response_content = str(boto3_response)
                    formatted_response = format_response_text(response_content)

                    await manager.send_message(json.dumps({
                        'response': formatted_response,
                        'sessionId': new_session_id
                    }), client_id)

            except Exception as e:
                error_message = str(e)
                logger.error(f'Error processing request with agent: {error_message}')

                # Provide more helpful error messages for common issues
                if 'throttlingexception' in error_message.lower():
                    error_message = 'Too many requests. The service is temporarily throttling requests. Please try again in a few moments.'
                elif 'AccessDeniedException' in error_message:
                    error_message = 'Access denied. Please check your AWS credentials and permissions.'
                elif 'ValidationException' in error_message and 'runtimeSessionId' in error_message:
                    error_message = 'Invalid session ID. Starting a new conversation.'
                    manager.set_session_id(client_id, None)  # Reset the session ID

                await manager.send_message(json.dumps({
                    'error': f'Error processing request with agent: {error_message}'
                }), client_id)

    except WebSocketDisconnect:
        manager.disconnect(client_id)
    except Exception as e:
        logger.error(f'WebSocket error: {str(e)}')
        manager.disconnect(client_id)
|
|
|
|
# Authentication routes
|
|
@app.get("/", response_class=HTMLResponse)
async def root(request: Request):
    """Serve the chat page to a signed-in user; otherwise redirect to ``/simple-login``."""
    user = await get_current_user(request)
    if user:
        context = {"request": request, "user": user}
        return templates.TemplateResponse("chat.html", context)

    # No user: send the visitor to the simple (non-Cognito) login form.
    return RedirectResponse(url="/simple-login")
|
|
|
|
@app.get("/login", response_class=HTMLResponse)
async def login_page(request: Request):
    """Render the Cognito login page, or redirect home if already signed in."""
    if await get_current_user(request):
        return RedirectResponse(url="/")

    context = {"request": request, "login_url": get_login_url()}
    return templates.TemplateResponse("login.html", context)
|
|
|
|
@app.get("/auth/callback")
async def auth_callback(
    request: Request,
    code: Optional[str] = None,
    error: Optional[str] = None,
    error_description: Optional[str] = None
):
    """Callback endpoint for Cognito authentication.

    On success the authorization ``code`` is exchanged for tokens, the ID
    token is validated, the user is stored in the session and the browser is
    redirected to ``/``. Every failure path re-renders ``login.html`` with an
    ``error`` message and HTTP 400.
    """

    def render_login_error(message):
        # Shared failure response: login page with a fresh login URL and the message.
        context = {
            "request": request,
            "login_url": get_login_url(),
            "error": message,
        }
        return templates.TemplateResponse("login.html", context, status_code=400)

    # The identity provider reported a failure.
    if error:
        error_msg = f"Authentication error: {error}"
        if error_description:
            error_msg += f" - {error_description}"
        logger.error(error_msg)
        return render_login_error(error_msg)

    # Nothing to exchange without an authorization code.
    if not code:
        logger.error("No authorization code provided")
        return render_login_error("No authorization code provided")

    try:
        # Exchange the authorization code for tokens, then validate the ID token.
        tokens = await exchange_code_for_tokens(code)
        id_token = tokens["id_token"]
        claims = await validate_token(id_token)

        # Store user info in session
        session_user = {
            "sub": claims["sub"],
            "email": claims.get("email", ""),
            "name": claims.get("name", ""),
            "access_token": tokens["access_token"],
            "id_token": id_token,
        }
        request.session["user"] = session_user

        # Redirect to the main page
        return RedirectResponse(url="/")

    except Exception as e:
        logger.error(f"Authentication error: {str(e)}")
        return render_login_error(str(e))
|
|
|
|
@app.get("/logout")
async def logout(request: Request):
    """Clear the session and auth cookies, then redirect to ``/simple-login``."""
    request.session.clear()

    redirect = RedirectResponse(url="/simple-login")
    # Remove every authentication cookie this endpoint knows about.
    for cookie_name in ("access_token", "simple_user"):
        redirect.delete_cookie(cookie_name)
    return redirect
|
|
|
|
@app.get("/profile")
async def profile(request: Request, user: dict = Depends(login_required)):
    """Return the current user (as resolved by ``login_required``) under the ``user`` key."""
    payload = {"user": user}
    return payload
|
|
|
|
@app.get("/simple-login", response_class=HTMLResponse)
async def simple_login_page(request: Request):
    """Render the simple (non-Cognito) login form, or redirect home if already signed in."""
    if await get_current_user(request):
        return RedirectResponse(url="/")

    context = {"request": request}
    return templates.TemplateResponse("simple_login.html", context)
|
|
|
|
@app.post("/simple-login")
async def simple_login_submit(request: Request, username: str = Form(...), password: str = Form(...)):
    """Process the simple login form and start a demo session.

    For demo purposes any username/password pair is accepted; ``password`` is
    required by the form but never checked. A real application would validate
    the credentials against a database or other authentication system.
    """
    # Store user info in session (fixed subject and placeholder tokens).
    demo_user = {
        "sub": "simple-user-123",
        "email": username,
        "name": username,
        "access_token": "demo-token",
        "id_token": "demo-token",
    }
    request.session["user"] = demo_user

    # 303 See Other makes the browser follow the redirect with GET after the POST.
    return RedirectResponse(url="/", status_code=303)
|
|
|
|
if __name__ == "__main__":
    # Script entry point: serve this module's ``app`` with uvicorn on all
    # interfaces, port 8000, with auto-reload enabled.
    import uvicorn

    server_options = {"host": "0.0.0.0", "port": 8000, "reload": True}
    uvicorn.run("main:app", **server_options)
|