rohillasandeep 17a75597fe
fix (02-use-cases): AWS Operations Agent updated with AgentCore Runtime (#177)
* feat: Add AWS Operations Agent with AgentCore Runtime

- Complete rewrite of AWS Operations Agent using Amazon Bedrock AgentCore
- Added comprehensive deployment scripts for DIY and SDK runtime modes
- Implemented OAuth2/PKCE authentication with Okta integration
- Added MCP (Model Context Protocol) tool support for AWS service operations
- Sanitized all sensitive information (account IDs, domains, client IDs) with placeholders
- Added support for 17 AWS services: EC2, S3, Lambda, CloudFormation, IAM, RDS, CloudWatch, Cost Explorer, ECS, EKS, SNS, SQS, DynamoDB, Route53, API Gateway, SES, Bedrock, SageMaker
- Includes chatbot client, gateway management scripts, and comprehensive testing
- Ready for public GitHub with security-cleared configuration files

Security: All sensitive values replaced with <YOUR_AWS_ACCOUNT_ID>, <YOUR_OKTA_DOMAIN>, <YOUR_OKTA_CLIENT_ID> placeholders

* Update AWS Operations Agent architecture diagram

---------

Co-authored-by: name <alias@amazon.com>
2025-07-31 14:59:30 -04:00

176 lines
7.9 KiB
Python

"""
AgentCore Configuration Validator
Validates AgentCore configuration against schema and business rules
"""
import re
import logging
from typing import Dict, Any, List
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
class ConfigValidator:
    """Validate AgentCore configuration dictionaries.

    The public entry points are ``validate_static`` and ``validate_dynamic``.
    Both return ``None`` on success and raise ``ValueError`` describing the
    first problem they find.

    ``_validate_sampling_rates`` and ``_validate_log_levels`` are available as
    helpers but are not invoked by either entry point in this class.
    """

    def __init__(self):
        """Compile the patterns and constants shared by the validators."""
        # arn:aws:<service>:<optional region>:<12-digit account>:<resource>
        self.arn_pattern = re.compile(r"^arn:aws:[\w-]+:[\w-]*:\d{12}:.*")
        # http:// or https:// followed by text containing no whitespace.
        self.url_pattern = re.compile(r"^https?://[^\s/$.?#].[^\s]*$")
        self.valid_log_levels = {"DEBUG", "INFO", "WARN", "WARNING", "ERROR", "CRITICAL"}

    def validate_static(self, config: Dict[str, Any]) -> None:
        """Validate static configuration.

        Requires the ``aws``, ``agents`` and ``okta`` sections and validates
        each of them; ``tools_schema`` is validated only when present.

        Args:
            config: Static configuration dictionary.

        Raises:
            ValueError: If a required section or field is missing or malformed.
        """
        self._validate_required_fields(config, ["aws", "agents", "okta"])
        self._validate_aws_config(config.get("aws", {}))
        self._validate_agent_config(config.get("agents", {}))
        self._validate_okta_config(config.get("okta", {}))

        # Validate tools schema if present
        if "tools_schema" in config:
            self._validate_tools_schema(config["tools_schema"])

    def validate_dynamic(self, config: Dict[str, Any]) -> None:
        """Validate dynamic configuration.

        Every section (``runtime``, ``mcp_lambda``, ``gateway``) is optional
        and is only checked when its key is present.

        Args:
            config: Dynamic configuration dictionary.

        Raises:
            ValueError: If an ARN or URL in a present section is malformed.
        """
        # Validate ARN formats
        if "runtime" in config:
            self._validate_runtime_arns(config["runtime"])
        if "mcp_lambda" in config:
            self._validate_mcp_lambda_config(config["mcp_lambda"])

        # Validate URLs
        if "gateway" in config:
            self._validate_gateway_config(config["gateway"])

    @staticmethod
    def _require_mapping(section: Any, name: str) -> None:
        """Raise ``ValueError`` unless ``section`` is a mapping.

        A key that is present but empty (for example ``okta:`` with no value in
        YAML) loads as ``None``; without this guard the ``.get`` calls in the
        validators would fail with ``AttributeError`` instead of ``ValueError``.
        """
        from collections.abc import Mapping  # local: keeps file-level imports unchanged

        if not isinstance(section, Mapping):
            raise ValueError(
                f"Configuration section '{name}' must be a mapping, got: {type(section).__name__}"
            )

    @staticmethod
    def _matches(pattern: Any, value: Any) -> bool:
        """Return True when ``value`` is a string accepted by ``pattern.match``.

        Non-string values return False so callers report them through
        ``ValueError`` rather than letting ``re`` raise ``TypeError``.
        """
        return isinstance(value, str) and pattern.match(value) is not None

    def _validate_required_fields(self, config: Dict[str, Any], required_fields: List[str]) -> None:
        """Validate that required fields exist.

        Raises:
            ValueError: Naming the first entry of ``required_fields`` that is
                not a key of ``config``.
        """
        for field in required_fields:
            if field not in config:
                raise ValueError(f"Missing required configuration field: {field}")

    def _validate_aws_config(self, aws_config: Dict[str, Any]) -> None:
        """Validate AWS configuration.

        Raises:
            ValueError: If ``region`` or ``account_id`` is missing or empty, or
                if ``account_id`` is not exactly 12 digits.
        """
        self._require_mapping(aws_config, "aws")
        if not aws_config.get("region"):
            raise ValueError("AWS region is required")
        if not aws_config.get("account_id"):
            raise ValueError("AWS account ID is required")

        # Validate account ID format (12 digits). fullmatch is used because
        # "$" also matches just before a trailing newline.
        account_id = str(aws_config.get("account_id", ""))
        if not re.fullmatch(r"\d{12}", account_id):
            raise ValueError(f"Invalid AWS account ID format: {account_id}")

    def _validate_agent_config(self, agents_config: Dict[str, Any]) -> None:
        """Validate agents configuration.

        Raises:
            ValueError: If ``modelid`` is missing or empty, or if
                ``max_concurrent`` is present and not a positive integer.
        """
        self._require_mapping(agents_config, "agents")
        if not agents_config.get("modelid"):
            raise ValueError("Agent model ID is required")

        # Validate max_concurrent if present
        max_concurrent = agents_config.get("max_concurrent")
        if max_concurrent is not None:
            # bool is a subclass of int, so True would otherwise pass as 1.
            is_integer = isinstance(max_concurrent, int) and not isinstance(max_concurrent, bool)
            if not is_integer or max_concurrent < 1:
                raise ValueError(f"max_concurrent must be a positive integer, got: {max_concurrent}")

    def _validate_okta_config(self, okta_config: Dict[str, Any]) -> None:
        """Validate Okta configuration.

        Raises:
            ValueError: If ``domain``, ``jwt.audience`` or ``jwt.discovery_url``
                is missing or empty, or if the discovery URL is malformed.
        """
        self._require_mapping(okta_config, "okta")
        if not okta_config.get("domain"):
            raise ValueError("Okta domain is required")

        jwt_config = okta_config.get("jwt", {})
        self._require_mapping(jwt_config, "okta.jwt")
        if not jwt_config.get("audience"):
            raise ValueError("Okta JWT audience is required")
        if not jwt_config.get("discovery_url"):
            raise ValueError("Okta JWT discovery URL is required")

        # Validate discovery URL format
        discovery_url = jwt_config.get("discovery_url", "")
        if not self._matches(self.url_pattern, discovery_url):
            raise ValueError(f"Invalid Okta discovery URL format: {discovery_url}")

    def _validate_tools_schema(self, tools_schema: List[Dict[str, Any]]) -> None:
        """Validate tools schema.

        Raises:
            ValueError: If ``tools_schema`` is not a list, an entry is not a
                dictionary, or an entry lacks a non-empty ``name`` or
                ``description`` or has no ``inputSchema`` key.
        """
        if not isinstance(tools_schema, list):
            raise ValueError("Tools schema must be a list")

        for i, tool in enumerate(tools_schema):
            if not isinstance(tool, dict):
                raise ValueError(f"Tool {i} must be a dictionary")
            if not tool.get("name"):
                raise ValueError(f"Tool {i} missing required 'name' field")
            if not tool.get("description"):
                raise ValueError(f"Tool {i} missing required 'description' field")
            # Only presence is required here; the value itself is not inspected.
            if "inputSchema" not in tool:
                raise ValueError(f"Tool {i} missing required 'inputSchema' field")

    def _validate_runtime_arns(self, runtime_config: Dict[str, Any]) -> None:
        """Validate runtime ARN formats.

        Checks the optional ``arn`` and ``endpoint_arn`` of the ``diy_agent``
        and ``sdk_agent`` entries when those entries are present.

        Raises:
            ValueError: If a non-empty ARN does not match the ARN pattern.
        """
        for agent_type in ["diy_agent", "sdk_agent"]:
            if agent_type in runtime_config:
                agent_config = runtime_config[agent_type]
                self._require_mapping(agent_config, f"runtime.{agent_type}")

                # Validate ARN format if present
                arn = agent_config.get("arn")
                if arn and not self._matches(self.arn_pattern, arn):
                    raise ValueError(f"Invalid ARN format for {agent_type}: {arn}")

                # Validate endpoint ARN format if present
                endpoint_arn = agent_config.get("endpoint_arn")
                if endpoint_arn and not self._matches(self.arn_pattern, endpoint_arn):
                    raise ValueError(f"Invalid endpoint ARN format for {agent_type}: {endpoint_arn}")

    def _validate_mcp_lambda_config(self, mcp_config: Dict[str, Any]) -> None:
        """Validate MCP lambda configuration.

        Raises:
            ValueError: If a non-empty ``function_arn`` or ``role_arn`` does
                not match the ARN pattern.
        """
        self._require_mapping(mcp_config, "mcp_lambda")

        # Validate function ARN if present
        function_arn = mcp_config.get("function_arn")
        if function_arn and not self._matches(self.arn_pattern, function_arn):
            raise ValueError(f"Invalid MCP lambda function ARN format: {function_arn}")

        # Validate role ARN if present
        role_arn = mcp_config.get("role_arn")
        if role_arn and not self._matches(self.arn_pattern, role_arn):
            raise ValueError(f"Invalid MCP lambda role ARN format: {role_arn}")

    def _validate_gateway_config(self, gateway_config: Dict[str, Any]) -> None:
        """Validate gateway configuration.

        Raises:
            ValueError: If a non-empty ``url`` does not match the URL pattern
                or a non-empty ``arn`` does not match the ARN pattern.
        """
        self._require_mapping(gateway_config, "gateway")

        # Validate gateway URL if present
        gateway_url = gateway_config.get("url")
        if gateway_url and not self._matches(self.url_pattern, gateway_url):
            raise ValueError(f"Invalid gateway URL format: {gateway_url}")

        # Validate gateway ARN if present
        gateway_arn = gateway_config.get("arn")
        if gateway_arn and not self._matches(self.arn_pattern, gateway_arn):
            raise ValueError(f"Invalid gateway ARN format: {gateway_arn}")

    def _validate_sampling_rates(self, config: Dict[str, Any]) -> None:
        """Validate sampling rates are between 0.0 and 1.0.

        Only ``observability.tracing.sampling_rate`` is checked, and only when
        it is an ``int`` or ``float``; values of other types are ignored.

        Raises:
            ValueError: If the numeric sampling rate is outside [0.0, 1.0].
        """
        def check_sampling_rate(value: Any, path: str) -> None:
            if isinstance(value, (int, float)):
                if not (0.0 <= value <= 1.0):
                    raise ValueError(f"Sampling rate at {path} must be between 0.0 and 1.0, got: {value}")

        # Check observability sampling rates
        obs_config = config.get("observability", {})
        if "tracing" in obs_config:
            tracing = obs_config["tracing"]
            if "sampling_rate" in tracing:
                check_sampling_rate(tracing["sampling_rate"], "observability.tracing.sampling_rate")

    def _validate_log_levels(self, config: Dict[str, Any]) -> None:
        """Validate log levels are valid.

        Only ``observability.logging.level`` is checked, case-insensitively,
        and only when it is a string; values of other types are ignored.

        Raises:
            ValueError: If the level is not one of ``self.valid_log_levels``.
        """
        def check_log_level(value: Any, path: str) -> None:
            if isinstance(value, str) and value.upper() not in self.valid_log_levels:
                # Sort so the message is identical on every run; set iteration
                # order for strings varies with hash randomization.
                valid = ', '.join(sorted(self.valid_log_levels))
                raise ValueError(f"Invalid log level at {path}: {value}. Valid levels: {valid}")

        # Check observability log levels
        obs_config = config.get("observability", {})
        if "logging" in obs_config:
            logging_config = obs_config["logging"]
            if "level" in logging_config:
                check_log_level(logging_config["level"], "observability.logging.level")