Godwin Vincent cd0a29d2ae
Device management agent - AgentCore runtime, observability, frontend added (#241)
* updated README.md file with bearer token generation

* updated README.md file with bearer token generation-removed client id and secret credentials

* removed hardcoded domain

* added agent runtime, frontend, observability and agentcore identity

* update README.md file to reflect frontend testing
2025-08-13 09:31:29 -07:00

731 lines
30 KiB
Python

"""
Device Remote Management AI Agent for Amazon Bedrock AgentCore Runtime
This version is adapted to work with Amazon Bedrock AgentCore Runtime
"""
import os
import json
import logging
import requests
import asyncio
from dotenv import load_dotenv
import utils
import access_token
# Import Strands Agents SDK
from strands import Agent
from strands.models import BedrockModel
from strands.agent.conversation_manager import SlidingWindowConversationManager
from mcp.client.streamable_http import streamablehttp_client
from strands.tools.mcp import MCPClient
from bedrock_agentcore.runtime import BedrockAgentCoreApp
# Load environment variables
load_dotenv()
# Initialize the AgentCore Runtime App
app = BedrockAgentCoreApp()
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S'
)
logger = logging.getLogger(__name__)
# Set logging level for specific libraries
logging.getLogger('requests').setLevel(logging.WARNING)
logging.getLogger('urllib3').setLevel(logging.WARNING)
logging.getLogger('mcp').setLevel(logging.INFO)
logging.getLogger('strands').setLevel(logging.INFO)
# MCP Server configuration
MCP_SERVER_URL = os.getenv("MCP_SERVER_URL")
logger.info(f"MCP_SERVER_URL set to: {MCP_SERVER_URL}")
# Configure conversation management for production
conversation_manager = SlidingWindowConversationManager(
window_size=25, # Limit history size
)
# Function to check if MCP server is running
def check_mcp_server():
try:
# Get the bearer token
jwt_token = os.getenv("BEARER_TOKEN")
logger.info(f"Checking MCP server at URL: {MCP_SERVER_URL}")
# If no bearer token, try to get one from Cognito
if not jwt_token:
logger.info("No bearer token available, trying to get one from Cognito...")
try:
jwt_token = access_token.get_gateway_access_token()
logger.info(f"Retrieved token: {jwt_token}")
logger.info(f"Cognito token obtained: {'Yes' if jwt_token else 'No'}")
except Exception as e:
logger.error(f"Error getting Cognito token: {str(e)}", exc_info=True)
if jwt_token:
headers = {"Authorization": f"Bearer {jwt_token}", "Content-Type": "application/json"}
payload = {
"jsonrpc": "2.0",
"id": "test",
"method": "tools/list",
"params": {}
}
try:
response = requests.post(f"{MCP_SERVER_URL}/mcp", headers=headers, json=payload, timeout=10)
logger.info(f"MCP server response status: {response.status_code}")
has_tools = "tools" in response.text
return has_tools
except requests.exceptions.RequestException as e:
logger.error(f"Request exception when checking MCP server: {str(e)}")
return False
else:
# Try without token for local testing
logger.info("No bearer token available, trying health endpoint")
try:
response = requests.get(f"{MCP_SERVER_URL}/health", timeout=5)
logger.info(f"Health endpoint response status: {response.status_code}")
return response.status_code == 200
except requests.exceptions.RequestException as e:
logger.error(f"Health endpoint request exception: {str(e)}")
return False
except Exception as e:
logger.error(f"Error checking MCP server: {str(e)}", exc_info=True)
return False
# Initialize Strands Agent with MCP tools
def initialize_agent():
try:
# Get OAuth token for authentication
logger.info("Starting agent initialization...")
# First try to get token from environment variable (for Docker)
jwt_token = os.getenv("BEARER_TOKEN")
# If not available in environment, try to get from Cognito
if not jwt_token:
logger.info("No token in environment, trying Cognito...")
try:
#jwt_token = asyncio.run(access_token.get_gateway_access_token())
jwt_token = access_token.get_gateway_access_token()
logger.info(f"Retrieved token: {jwt_token}")
except Exception as e:
logger.error(f"Error getting Cognito token: {str(e)}", exc_info=True)
# Create MCP client with authentication headers
gateway_endpoint = os.getenv("gateway_endpoint", MCP_SERVER_URL)
logger.info(f"Using gateway endpoint: {gateway_endpoint}")
# Try to resolve the hostname to check network connectivity
try:
import socket
hostname = gateway_endpoint.split("://")[1].split("/")[0]
ip_address = socket.gethostbyname(hostname)
logger.info(f"Hostname resolved to IP: {ip_address}")
except Exception as e:
logger.error(f"Error resolving hostname: {str(e)}")
headers = {"Authorization": f"Bearer {jwt_token}"} if jwt_token else {}
try:
logger.info("Creating MCP client...")
# Create the MCP client
mcp_client = MCPClient(lambda: streamablehttp_client(
url = f"{gateway_endpoint}/mcp",
headers=headers
))
logger.info("MCP Client setup complete")
# Enter the context manager
mcp_client.__enter__()
# Get the tools from the MCP server
logger.info("Listing tools from MCP server...")
tools = mcp_client.list_tools_sync()
logger.info(f"Loaded {len(tools)} tools from MCP server")
# Log available tools
if tools and len(tools) > 0:
# Try to access the tool name using the correct attribute
tool_names = []
for tool in tools:
# Check if the tool has a 'schema' attribute that might contain the name
if hasattr(tool, 'schema') and hasattr(tool.schema, 'name'):
tool_names.append(tool.schema.name)
# Or if it has a direct attribute that contains the name
elif hasattr(tool, 'tool_name'):
tool_names.append(tool.tool_name)
# Or if it's in the __dict__
elif '_name' in vars(tool):
tool_names.append(vars(tool)['_name'])
else:
# If we can't find the name, use a placeholder
tool_names.append(f"Tool-{id(tool)}")
logger.info(f"Available tools: {', '.join(tool_names)}")
except Exception as e:
logger.error(f"Error setting up MCP client: {str(e)}", exc_info=True)
return None, None
# Create an agent with these tools
try:
logger.info("Creating Strands Agent with tools...")
model_id = "us.anthropic.claude-3-7-sonnet-20250219-v1:0" # Using Claude Sonnet
model = BedrockModel(model_id=model_id)
agent = Agent(
model=model,
tools=tools,
conversation_manager=conversation_manager,
system_prompt="""
You are an AI assistant for Device Remote Management. Help the user with their query.
You have access to tools that can retrieve real data from the Device Remote Management system.
Available tools:
- list_devices: List all devices in the system
- get_device_settings: Get settings for a specific device
- list_wifi_networks: List all WiFi networks for a specific device
- list_users: List all users in the system
- query_user_activity: Query user activity within a time period
- update_wifi_ssid: Update the SSID of a Wi-Fi network on a device
- update_wifi_security: Update the security type of a Wi-Fi network on a device
Use these tools to help users manage their Device devices and accounts.
"""
)
logger.info("Agent created successfully")
return agent, mcp_client
except Exception as e:
logger.error(f"Error creating agent: {str(e)}", exc_info=True)
return None, None
except Exception as e:
logger.error(f"Error initializing agent: {str(e)}", exc_info=True)
return None, None
# Initialize the agent if MCP server is running
agent = None
mcp_client = None
if check_mcp_server():
agent, mcp_client = initialize_agent()
if agent:
logger.info("Agent initialized successfully")
else:
logger.warning("Failed to initialize agent")
else:
logger.warning("MCP server is not running. Agent initialization skipped.")
# Function to format response for better display
def format_response(text):
"""Format the response for better display using plain text formatting"""
if not isinstance(text, str):
text = str(text)
# Check if the response contains JSON data
json_start = text.find('{')
json_end = text.rfind('}')
if json_start >= 0 and json_end > json_start:
try:
json_str = text[json_start:json_end+1]
json_data = json.loads(json_str)
# Format based on the type of data
if isinstance(json_data, list) and len(json_data) > 0:
# Check if it's a list of devices
if isinstance(json_data[0], dict) and all(key in json_data[0] for key in ['device_id', 'model']):
formatted_text = format_device_list(json_data)
return text[:json_start] + formatted_text + text[json_end+1:]
# Check if it's a list of users
elif isinstance(json_data[0], dict) and all(key in json_data[0] for key in ['user_id', 'username']):
formatted_text = format_user_list(json_data)
return text[:json_start] + formatted_text + text[json_end+1:]
# Check if it's a list of activities
elif isinstance(json_data[0], dict) and all(key in json_data[0] for key in ['user_id', 'activity_type']):
formatted_text = format_activity_list(json_data)
return text[:json_start] + formatted_text + text[json_end+1:]
# Generic list of objects
elif isinstance(json_data[0], dict):
formatted_text = format_generic_list(json_data)
return text[:json_start] + formatted_text + text[json_end+1:]
# Single object with wifi_networks
elif isinstance(json_data, dict) and 'wifi_networks' in json_data:
formatted_text = format_wifi_networks(json_data)
return text[:json_start] + formatted_text + text[json_end+1:]
# Single object
elif isinstance(json_data, dict):
# Check if it's device settings
if 'device_id' in json_data and 'settings' in json_data:
formatted_text = format_device_settings(json_data)
return text[:json_start] + formatted_text + text[json_end+1:]
# Check if it's a WiFi update result
elif 'device_id' in json_data and 'network_id' in json_data and 'old_ssid' in json_data:
formatted_text = format_wifi_update(json_data)
return text[:json_start] + formatted_text + text[json_end+1:]
# Check if it's a WiFi security update result
elif 'device_id' in json_data and 'network_id' in json_data and 'old_security_type' in json_data:
formatted_text = format_wifi_security_update(json_data)
return text[:json_start] + formatted_text + text[json_end+1:]
# Generic object
else:
formatted_text = format_generic_object(json_data)
return text[:json_start] + formatted_text + text[json_end+1:]
except Exception as e:
logger.debug(f"Error formatting JSON: {str(e)}")
# If no JSON formatting was applied, return the original text
return text
def format_device_list(devices):
"""Format a list of devices using plain text formatting"""
if not devices or not isinstance(devices, list):
return "No devices found."
result = "Devices in Device Remote Management System\n\n"
# Add a header row
result += "Name | Device ID | Model | Status | IP Address | Last Connected\n"
result += "----------------------|------------|-----------|------------|-----------------|---------------\n"
# Add each device
for device in devices:
name = device.get('name', 'N/A')
device_id = device.get('device_id', 'N/A')
model = device.get('model', 'N/A')
status = device.get('connection_status', 'N/A')
ip = device.get('ip_address', 'N/A')
last_connected = device.get('last_connected', 'N/A')
# Format the row with fixed column widths
result += f"{name[:20].ljust(20)} | {device_id[:10].ljust(10)} | {model[:9].ljust(9)} | {status[:10].ljust(10)} | {ip[:15].ljust(15)} | {str(last_connected)[:15]}\n"
return result
def format_user_list(users):
"""Format a list of users using plain text formatting"""
if not users or not isinstance(users, list):
return "No users found."
result = "Users in Device Remote Management System\n\n"
# Add a header row
result += "Username | User ID | Email | Role | Last Login\n"
result += "----------------------|------------|--------------------------------|------------|---------------\n"
# Add each user
for user in users:
username = user.get('username', 'N/A')
user_id = user.get('user_id', 'N/A')
email = user.get('email', 'N/A')
role = user.get('role', 'N/A')
last_login = user.get('last_login', 'N/A')
# Format the row with fixed column widths
result += f"{username[:20].ljust(20)} | {user_id[:10].ljust(10)} | {email[:30].ljust(30)} | {role[:10].ljust(10)} | {str(last_login)[:15]}\n"
return result
def format_activity_list(activities):
"""Format a list of user activities using plain text formatting"""
if not activities or not isinstance(activities, list):
return "No activities found."
result = "User Activities in Device Remote Management System\n\n"
# Add a header row
result += "Username | Activity Type | Description | Timestamp | IP Address\n"
result += "----------------------|---------------------|-------------------------------|---------------------|---------------\n"
# Add each activity
for activity in activities:
username = activity.get('username', 'N/A')
activity_type = activity.get('activity_type', 'N/A')
description = activity.get('description', 'N/A')
timestamp = activity.get('timestamp', 'N/A')
ip_address = activity.get('ip_address', 'N/A')
# Format the row with fixed column widths
result += f"{username[:20].ljust(20)} | {activity_type[:19].ljust(19)} | {description[:30].ljust(30)} | {str(timestamp)[:19].ljust(19)} | {ip_address[:15]}\n"
return result
def format_wifi_networks(data):
"""Format WiFi networks list using plain text formatting"""
if not data or not isinstance(data, dict) or 'wifi_networks' not in data:
return "No WiFi networks found."
device_id = data.get('device_id', 'Unknown')
device_name = data.get('device_name', 'Unknown Device')
networks = data.get('wifi_networks', [])
if not networks:
return f"No WiFi networks found for device {device_name} ({device_id})."
result = f"WiFi Networks for Device: {device_name} ({device_id})\n\n"
# Add a header row
result += "SSID | Network ID | Security | Enabled | Channel | Signal\n"
result += "----------------------------------|------------|------------|---------|---------|-------\n"
# Add each network
for network in networks:
ssid = network.get('ssid', 'N/A')
network_id = network.get('network_id', 'N/A')
security = network.get('security_type', 'N/A')
enabled = network.get('enabled', 'N/A')
channel = network.get('channel', 'N/A')
signal = network.get('signal_strength', 'N/A')
# Format the row with fixed column widths
result += f"{ssid[:34].ljust(34)} | {str(network_id)[:10].ljust(10)} | {str(security)[:10].ljust(10)} | {str(enabled)[:7].ljust(7)} | {str(channel)[:7].ljust(7)} | {str(signal)[:7]}\n"
return result
def format_device_settings(settings):
"""Format device settings using plain text formatting"""
if not settings or not isinstance(settings, dict):
return "No settings found."
device_name = settings.get('device_name', 'Unknown Device')
device_id = settings.get('device_id', 'N/A')
model = settings.get('model', 'N/A')
firmware = settings.get('firmware_version', 'N/A')
status = settings.get('connection_status', 'N/A')
result = f"Settings for Device: {device_name} ({device_id})\n\n"
# Device information section
result += "DEVICE INFORMATION\n"
result += "=================\n"
result += f"Model: {model}\n"
result += f"Firmware: {firmware}\n"
result += f"Status: {status}\n\n"
# Configuration settings section
if 'settings' in settings and settings['settings']:
result += "CONFIGURATION SETTINGS\n"
result += "=====================\n"
# Get the maximum key length for alignment
max_key_length = max([len(key) for key in settings['settings'].keys()])
for key, value in sorted(settings['settings'].items()):
result += f"{key.ljust(max_key_length + 2)}: {value}\n"
return result
def format_wifi_update(result):
"""Format WiFi SSID update result using plain text formatting"""
if not result or not isinstance(result, dict):
return "No update result available."
if 'error' in result:
return f"Error updating WiFi SSID: {result['error']}"
device_name = result.get('device_name', 'Unknown Device')
device_id = result.get('device_id', 'N/A')
network_id = result.get('network_id', 'N/A')
old_ssid = result.get('old_ssid', 'N/A')
new_ssid = result.get('new_ssid', 'N/A')
status = result.get('status', 'N/A')
output = "WiFi SSID Update Successful\n"
output += "==========================\n\n"
output += f"Device: {device_name} (ID: {device_id})\n"
output += f"Network ID: {network_id}\n"
output += f"Previous SSID: {old_ssid}\n"
output += f"New SSID: {new_ssid}\n"
output += f"Status: {status}\n"
return output
def format_wifi_security_update(result):
"""Format WiFi security type update result using plain text formatting"""
if not result or not isinstance(result, dict):
return "No update result available."
if 'error' in result:
return f"Error updating WiFi security type: {result['error']}"
device_name = result.get('device_name', 'Unknown Device')
device_id = result.get('device_id', 'N/A')
network_id = result.get('network_id', 'N/A')
ssid = result.get('ssid', 'N/A')
old_security_type = result.get('old_security_type', 'N/A')
new_security_type = result.get('new_security_type', 'N/A')
status = result.get('status', 'N/A')
output = "WiFi Security Type Update Successful\n"
output += "==================================\n\n"
output += f"Device: {device_name} (ID: {device_id})\n"
output += f"Network ID: {network_id}\n"
output += f"SSID: {ssid}\n"
output += f"Previous Security: {old_security_type}\n"
output += f"New Security: {new_security_type}\n"
output += f"Status: {status}\n"
return output
def format_generic_list(items):
"""Format a generic list of objects using plain text formatting"""
if not items or not isinstance(items, list):
return "No items found."
# Get all unique keys from all items
all_keys = set()
for item in items:
if isinstance(item, dict):
all_keys.update(item.keys())
# If no keys found, just return a simple list
if not all_keys:
result = "Results:\n\n"
for i, item in enumerate(items, 1):
result += f"{i}. {item}\n"
return result
# Convert keys to a sorted list for consistent column order
keys = sorted(all_keys)
# Create a header row with column names
result = "Results:\n\n"
header = ""
separator = ""
# Calculate column widths (minimum 10 characters, maximum 20)
col_widths = {}
for key in keys:
# Find the maximum length of values for this key
max_val_length = max([len(str(item.get(key, ''))) for item in items if isinstance(item, dict)] + [len(key)])
col_widths[key] = min(max(max_val_length, 10), 20)
# Create header and separator
for key in keys:
width = col_widths[key]
header += f"{key[:width].ljust(width)} | "
separator += "-" * width + "-|-"
# Remove trailing separator characters
header = header[:-3]
separator = separator[:-3]
result += header + "\n" + separator + "\n"
# Add rows
for item in items:
if not isinstance(item, dict):
continue
row = ""
for key in keys:
width = col_widths[key]
value = str(item.get(key, 'N/A'))
row += f"{value[:width].ljust(width)} | "
# Remove trailing separator characters
row = row[:-3]
result += row + "\n"
return result
def format_generic_object(obj):
"""Format a generic object using plain text formatting"""
if not obj or not isinstance(obj, dict):
return "No data available."
if 'error' in obj:
return f"Error: {obj['error']}"
result = "Result:\n\n"
# Get the maximum key length for alignment
max_key_length = max([len(key) for key in obj.keys()])
for key, value in sorted(obj.items()):
if isinstance(value, dict):
# Nested object
result += f"{key}:\n"
result += "-" * len(key) + "\n"
# Get the maximum nested key length for alignment
nested_max_key_length = max([len(k) for k in value.keys()]) if value else 0
for k, v in sorted(value.items()):
result += f" {k.ljust(nested_max_key_length + 2)}: {v}\n"
result += "\n"
elif isinstance(value, list):
# List of values
result += f"{key}:\n"
result += "-" * len(key) + "\n"
if all(isinstance(item, dict) for item in value):
# List of objects - get all unique keys
all_keys = set()
for item in value:
all_keys.update(item.keys())
# Calculate column widths
col_widths = {}
for k in all_keys:
max_val_length = max([len(str(item.get(k, ''))) for item in value] + [len(k)])
col_widths[k] = min(max(max_val_length, 10), 20)
# Create header
header = " "
separator = " "
for k in sorted(all_keys):
width = col_widths[k]
header += f"{k[:width].ljust(width)} | "
separator += "-" * width + "-|-"
# Remove trailing separator characters
header = header[:-3]
separator = separator[:-3]
result += header + "\n" + separator + "\n"
# Add rows
for item in value:
row = " "
for k in sorted(all_keys):
width = col_widths[k]
v = str(item.get(k, 'N/A'))
row += f"{v[:width].ljust(width)} | "
# Remove trailing separator characters
row = row[:-3]
result += row + "\n"
else:
# Simple list
for i, item in enumerate(value, 1):
result += f" {i}. {item}\n"
result += "\n"
else:
# Simple value
result += f"{key.ljust(max_key_length + 2)}: {value}\n"
return result
@app.entrypoint
async def process_request(payload):
"""
Process requests from AgentCore Runtime with streaming support
This is the entry point for the AgentCore Runtime
"""
global agent, mcp_client
try:
# Extract the user message from the payload
user_message = payload.get("prompt", "No prompt found in input, please provide a message")
logger.info(f"Received user message: {user_message}")
# Check if agent is initialized
if not agent:
logger.info("Agent not initialized, checking MCP server status...")
# Try to initialize the agent if MCP server is running
if check_mcp_server():
logger.info("MCP server is running, attempting to initialize agent...")
agent, mcp_client = initialize_agent()
if not agent:
error_msg = "Failed to initialize agent. Please ensure MCP server is running correctly."
logger.error(error_msg)
yield {"error": error_msg}
return
logger.info("Agent initialized successfully")
else:
error_msg = "Agent is not initialized. Please ensure MCP server is running."
logger.error(error_msg)
yield {"error": error_msg}
return
# Use Strands Agent to process the message with streaming
logger.info("Processing message with Strands Agent (streaming)...")
try:
# Stream response using agent.stream_async
stream = agent.stream_async(user_message)
async for event in stream:
logger.debug(f"Streaming event: {event}")
# Process different event types
if "data" in event:
# Text chunk from the model
chunk = event["data"]
formatted_chunk = format_response(chunk)
yield {
"type": "chunk",
"data": chunk,
"formatted": formatted_chunk
}
elif "current_tool_use" in event:
# Tool use information
tool_info = event["current_tool_use"]
yield {
"type": "tool_use",
"tool_name": tool_info.get("name", "Unknown tool"),
"tool_input": tool_info.get("input", {}),
"tool_id": tool_info.get("toolUseId", "")
}
elif "reasoning" in event and event["reasoning"]:
# Reasoning information
yield {
"type": "reasoning",
"reasoning_text": event.get("reasoningText", "")
}
elif "result" in event:
# Final result
result = event["result"]
if hasattr(result, 'message') and hasattr(result.message, 'content'):
if isinstance(result.message.content, list) and len(result.message.content) > 0:
final_response = result.message.content[0].get('text', '')
else:
final_response = str(result.message.content)
else:
final_response = str(result)
yield {
"type": "complete",
"final_response": format_response(final_response)
}
else:
# Pass through other events
yield event
except Exception as e:
logger.error(f"Error in streaming mode: {str(e)}", exc_info=True)
yield {"error": f"Error processing request with agent (streaming): {str(e)}"}
except Exception as e:
error_msg = f"Error processing request: {str(e)}"
logger.error(error_msg, exc_info=True)
yield {"error": error_msg}
if __name__ == "__main__":
# Test MCP server connection at startup
logger.info("Testing MCP server connection at startup...")
try:
jwt_token = access_token.get_gateway_access_token()
headers = {"Authorization": f"Bearer {jwt_token}", "Content-Type": "application/json"} if jwt_token else {"Content-Type": "application/json"}
payload = {
"jsonrpc": "2.0",
"id": "startup-test",
"method": "tools/list",
"params": {}
}
response = requests.post(f"{MCP_SERVER_URL}/mcp", headers=headers, json=payload, timeout=10)
logger.info(f"Direct test response status: {response.status_code}")
except Exception as e:
logger.error(f"Error in direct test: {str(e)}", exc_info=True)
# Run the AgentCore Runtime App
app.run()