mirror of
https://github.com/awslabs/amazon-bedrock-agentcore-samples.git
synced 2025-09-08 20:50:46 +00:00
* feat: Add AWS Operations Agent with AgentCore Runtime - Complete rewrite of AWS Operations Agent using Amazon Bedrock AgentCore - Added comprehensive deployment scripts for DIY and SDK runtime modes - Implemented OAuth2/PKCE authentication with Okta integration - Added MCP (Model Context Protocol) tool support for AWS service operations - Sanitized all sensitive information (account IDs, domains, client IDs) with placeholders - Added support for 17 AWS services: EC2, S3, Lambda, CloudFormation, IAM, RDS, CloudWatch, Cost Explorer, ECS, EKS, SNS, SQS, DynamoDB, Route53, API Gateway, SES, Bedrock, SageMaker - Includes chatbot client, gateway management scripts, and comprehensive testing - Ready for public GitHub with security-cleared configuration files Security: All sensitive values replaced with <YOUR_AWS_ACCOUNT_ID>, <YOUR_OKTA_DOMAIN>, <YOUR_OKTA_CLIENT_ID> placeholders * Update AWS Operations Agent architecture diagram --------- Co-authored-by: name <alias@amazon.com>
176 lines
7.9 KiB
Python
176 lines
7.9 KiB
Python
"""
|
|
AgentCore Configuration Validator
|
|
Validates AgentCore configuration against schema and business rules
|
|
"""
|
|
|
|
import re
|
|
import logging
|
|
from typing import Dict, Any, List
|
|
|
|
# Module-level logger named after this module (`__name__`).
# NOTE(review): nothing else visible in this file references it.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class ConfigValidator:
    """Validate AgentCore configuration.

    Every validation failure is reported as a ``ValueError``; the public
    entry points return ``None`` when the configuration is acceptable.

    Sections that are present but empty (``None``, as produced by an empty
    YAML key) are treated like an empty mapping, and sections that are not
    mappings are rejected with a ``ValueError`` instead of crashing with an
    ``AttributeError``/``TypeError``.

    NOTE(review): ``_validate_sampling_rates`` and ``_validate_log_levels``
    are not invoked by ``validate_static`` or ``validate_dynamic`` in this
    class; confirm whether callers are expected to call them directly.
    """

    def __init__(self):
        """Initialize configuration validator."""
        # Only the "aws" partition and a 12-digit account ID are accepted.
        self.arn_pattern = re.compile(r"^arn:aws:[\w-]+:[\w-]*:\d{12}:.*")
        self.url_pattern = re.compile(r"^https?://[^\s/$.?#].[^\s]*$")
        self.valid_log_levels = {"DEBUG", "INFO", "WARN", "WARNING", "ERROR", "CRITICAL"}

    # ------------------------------------------------------------------
    # Shared helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _as_mapping(value: Any, name: str) -> Dict[str, Any]:
        """Return ``value`` as a dict, treating ``None`` as an empty section.

        Raises:
            ValueError: if ``value`` is neither ``None`` nor a dict.
        """
        if value is None:
            return {}
        if not isinstance(value, dict):
            raise ValueError(
                f"Configuration section '{name}' must be a mapping, got: {type(value).__name__}"
            )
        return value

    @staticmethod
    def _require_match(pattern: "re.Pattern[str]", value: Any, label: str) -> None:
        """Raise ``ValueError`` unless ``value`` is a string matching ``pattern``.

        ``label`` completes the message ``"Invalid <label>: <value>"``.
        Non-string values are reported as invalid rather than being passed
        to the regex engine (which would raise ``TypeError``).
        """
        if not isinstance(value, str) or not pattern.match(value):
            raise ValueError(f"Invalid {label}: {value}")

    # ------------------------------------------------------------------
    # Public entry points
    # ------------------------------------------------------------------

    def validate_static(self, config: Dict[str, Any]) -> None:
        """Validate static configuration.

        Requires the ``aws``, ``agents`` and ``okta`` sections and validates
        each of them; ``tools_schema`` is validated only when present.

        Raises:
            ValueError: on the first rule that is violated.
        """
        config = self._as_mapping(config, "configuration")
        self._validate_required_fields(config, ["aws", "agents", "okta"])
        self._validate_aws_config(config.get("aws", {}))
        self._validate_agent_config(config.get("agents", {}))
        self._validate_okta_config(config.get("okta", {}))

        # Validate tools schema if present
        if "tools_schema" in config:
            self._validate_tools_schema(config["tools_schema"])

    def validate_dynamic(self, config: Dict[str, Any]) -> None:
        """Validate dynamic configuration.

        Each of the ``runtime``, ``mcp_lambda`` and ``gateway`` sections is
        optional and is validated only when its key is present.

        Raises:
            ValueError: on the first malformed ARN or URL.
        """
        config = self._as_mapping(config, "configuration")

        # Validate ARN formats
        if "runtime" in config:
            self._validate_runtime_arns(config["runtime"])

        if "mcp_lambda" in config:
            self._validate_mcp_lambda_config(config["mcp_lambda"])

        # Validate URLs
        if "gateway" in config:
            self._validate_gateway_config(config["gateway"])

    # ------------------------------------------------------------------
    # Static configuration rules
    # ------------------------------------------------------------------

    def _validate_required_fields(self, config: Dict[str, Any], required_fields: List[str]) -> None:
        """Validate that required fields exist (key presence only)."""
        config = self._as_mapping(config, "configuration")
        for field in required_fields:
            if field not in config:
                raise ValueError(f"Missing required configuration field: {field}")

    def _validate_aws_config(self, aws_config: Dict[str, Any]) -> None:
        """Validate AWS configuration: region and a 12-digit account ID."""
        aws_config = self._as_mapping(aws_config, "aws")

        if not aws_config.get("region"):
            raise ValueError("AWS region is required")

        raw_account_id = aws_config.get("account_id")
        if not raw_account_id:
            raise ValueError("AWS account ID is required")

        # Validate account ID format (12 digits). The value is stringified so
        # an integer from the config loader is accepted; fullmatch with an
        # explicit ASCII class rejects a trailing newline and non-ASCII
        # digits, both of which `re.match(r"^\d{12}$")` lets through.
        account_id = str(raw_account_id)
        if not re.fullmatch(r"[0-9]{12}", account_id):
            raise ValueError(f"Invalid AWS account ID format: {account_id}")

    def _validate_agent_config(self, agents_config: Dict[str, Any]) -> None:
        """Validate agents configuration: model ID and optional max_concurrent."""
        agents_config = self._as_mapping(agents_config, "agents")

        if not agents_config.get("modelid"):
            raise ValueError("Agent model ID is required")

        # Validate max_concurrent if present
        max_concurrent = agents_config.get("max_concurrent")
        if max_concurrent is None:
            return

        # bool is a subclass of int, so exclude it explicitly: `True` would
        # otherwise be accepted as the positive integer 1.
        is_real_int = isinstance(max_concurrent, int) and not isinstance(max_concurrent, bool)
        if not is_real_int or max_concurrent < 1:
            raise ValueError(f"max_concurrent must be a positive integer, got: {max_concurrent}")

    def _validate_okta_config(self, okta_config: Dict[str, Any]) -> None:
        """Validate Okta configuration: domain, JWT audience and discovery URL."""
        okta_config = self._as_mapping(okta_config, "okta")

        if not okta_config.get("domain"):
            raise ValueError("Okta domain is required")

        jwt_config = self._as_mapping(okta_config.get("jwt"), "okta.jwt")
        if not jwt_config.get("audience"):
            raise ValueError("Okta JWT audience is required")

        discovery_url = jwt_config.get("discovery_url")
        if not discovery_url:
            raise ValueError("Okta JWT discovery URL is required")

        # Validate discovery URL format
        self._require_match(self.url_pattern, discovery_url, "Okta discovery URL format")

    def _validate_tools_schema(self, tools_schema: List[Dict[str, Any]]) -> None:
        """Validate tools schema.

        The schema must be a list of dicts, each with a non-empty ``name``,
        a non-empty ``description`` and an ``inputSchema`` key.
        """
        if not isinstance(tools_schema, list):
            raise ValueError("Tools schema must be a list")

        for i, tool in enumerate(tools_schema):
            if not isinstance(tool, dict):
                raise ValueError(f"Tool {i} must be a dictionary")

            if not tool.get("name"):
                raise ValueError(f"Tool {i} missing required 'name' field")

            if not tool.get("description"):
                raise ValueError(f"Tool {i} missing required 'description' field")

            if "inputSchema" not in tool:
                raise ValueError(f"Tool {i} missing required 'inputSchema' field")

    # ------------------------------------------------------------------
    # Dynamic configuration rules
    # ------------------------------------------------------------------

    def _validate_runtime_arns(self, runtime_config: Dict[str, Any]) -> None:
        """Validate runtime ARN formats for the ``diy_agent``/``sdk_agent`` entries.

        Missing or empty ``arn``/``endpoint_arn`` values are skipped.
        """
        runtime_config = self._as_mapping(runtime_config, "runtime")

        for agent_type in ("diy_agent", "sdk_agent"):
            if agent_type not in runtime_config:
                continue

            agent_config = self._as_mapping(runtime_config[agent_type], f"runtime.{agent_type}")

            # Validate ARN format if present
            arn = agent_config.get("arn")
            if arn:
                self._require_match(self.arn_pattern, arn, f"ARN format for {agent_type}")

            # Validate endpoint ARN format if present
            endpoint_arn = agent_config.get("endpoint_arn")
            if endpoint_arn:
                self._require_match(
                    self.arn_pattern, endpoint_arn, f"endpoint ARN format for {agent_type}"
                )

    def _validate_mcp_lambda_config(self, mcp_config: Dict[str, Any]) -> None:
        """Validate MCP lambda configuration (optional function and role ARNs)."""
        mcp_config = self._as_mapping(mcp_config, "mcp_lambda")

        # Validate function ARN if present
        function_arn = mcp_config.get("function_arn")
        if function_arn:
            self._require_match(self.arn_pattern, function_arn, "MCP lambda function ARN format")

        # Validate role ARN if present
        role_arn = mcp_config.get("role_arn")
        if role_arn:
            self._require_match(self.arn_pattern, role_arn, "MCP lambda role ARN format")

    def _validate_gateway_config(self, gateway_config: Dict[str, Any]) -> None:
        """Validate gateway configuration (optional URL and ARN)."""
        gateway_config = self._as_mapping(gateway_config, "gateway")

        # Validate gateway URL if present
        gateway_url = gateway_config.get("url")
        if gateway_url:
            self._require_match(self.url_pattern, gateway_url, "gateway URL format")

        # Validate gateway ARN if present
        gateway_arn = gateway_config.get("arn")
        if gateway_arn:
            self._require_match(self.arn_pattern, gateway_arn, "gateway ARN format")

    # ------------------------------------------------------------------
    # Observability rules (not called by validate_static/validate_dynamic)
    # ------------------------------------------------------------------

    def _validate_sampling_rates(self, config: Dict[str, Any]) -> None:
        """Validate sampling rates are between 0.0 and 1.0.

        Only ``observability.tracing.sampling_rate`` is checked. Values that
        are not ``int``/``float`` are ignored, as in the original behaviour.
        """
        config = self._as_mapping(config, "configuration")
        obs_config = self._as_mapping(config.get("observability"), "observability")
        tracing = self._as_mapping(obs_config.get("tracing"), "observability.tracing")

        if "sampling_rate" not in tracing:
            return

        value = tracing["sampling_rate"]
        path = "observability.tracing.sampling_rate"
        if isinstance(value, (int, float)) and not (0.0 <= value <= 1.0):
            raise ValueError(f"Sampling rate at {path} must be between 0.0 and 1.0, got: {value}")

    def _validate_log_levels(self, config: Dict[str, Any]) -> None:
        """Validate log levels are valid.

        Only ``observability.logging.level`` is checked, case-insensitively.
        Non-string values are ignored, as in the original behaviour.
        """
        config = self._as_mapping(config, "configuration")
        obs_config = self._as_mapping(config.get("observability"), "observability")
        logging_config = self._as_mapping(obs_config.get("logging"), "observability.logging")

        if "level" not in logging_config:
            return

        value = logging_config["level"]
        path = "observability.logging.level"
        if isinstance(value, str) and value.upper() not in self.valid_log_levels:
            # Sort the set so the message is identical on every run; plain set
            # iteration order depends on string hash randomization.
            valid = ", ".join(sorted(self.valid_log_levels))
            raise ValueError(f"Invalid log level at {path}: {value}. Valid levels: {valid}")