import json
import boto3
import psycopg2
import os
import re
import time
import logging
from datetime import datetime
from botocore.exceptions import ClientError
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class QueryComplexityError(Exception):
"""Custom exception for query complexity violations"""
pass
class QueryLimitError(Exception):
"""Custom exception for query limit violations"""
pass
def analyze_query_complexity(query):
"""
Analyze query complexity and potential resource impact
Args:
query (str): SQL query to analyze
Returns:
dict: Complexity metrics
Raises:
QueryComplexityError: If query is too complex
"""
query_lower = query.lower()
complexity_score = 0
warnings = []
    # Check for joins (count occurrences of the JOIN keyword; the INNER/LEFT/
    # RIGHT/FULL variants all contain it, so one word-boundary pattern covers them)
    join_count = len(re.findall(r'\bjoin\b', query_lower))
    complexity_score += join_count * 2
if join_count > 3:
warnings.append(f"Query contains {join_count} joins - consider simplifying")
# Check for subqueries
    subquery_count = len(re.findall(r'\(\s*select\b', query_lower))
complexity_score += subquery_count * 3
if subquery_count > 2:
warnings.append(f"Query contains {subquery_count} subqueries - consider restructuring")
# Check for aggregations
agg_functions = ['count(', 'sum(', 'avg(', 'max(', 'min(']
agg_count = sum(query_lower.count(func) for func in agg_functions)
complexity_score += agg_count
    # Check for window functions (allow whitespace between OVER and the parenthesis)
    if re.search(r'\bover\s*\(', query_lower) or 'partition by' in query_lower:
        complexity_score += 3
        warnings.append("Query uses window functions - monitor performance")
# Check for complex WHERE conditions
where_pos = query_lower.find('where')
if where_pos != -1:
where_clause = query_lower[where_pos:]
and_count = where_clause.count(' and ')
or_count = where_clause.count(' or ')
complexity_score += (and_count + or_count)
if (and_count + or_count) > 5:
warnings.append(f"Complex WHERE clause with {and_count + or_count} conditions")
return {
'complexity_score': complexity_score,
'warnings': warnings,
'join_count': join_count,
'subquery_count': subquery_count,
'aggregation_count': agg_count
}
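
# Usage sketch for analyze_query_complexity (illustrative query; scores come
# from the heuristics above, so treat exact numbers as approximate):
#
#   metrics = analyze_query_complexity(
#       "SELECT c.id, COUNT(*) FROM customers c "
#       "JOIN orders o ON o.customer_id = c.id "
#       "WHERE c.active AND o.total > 100 GROUP BY c.id"
#   )
#   metrics['join_count']        # 1
#   metrics['complexity_score']  # small; well under the default limit of 15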
def validate_and_execute_queries(secret_name, query, max_rows=20,
max_statements=5, max_total_rows=1000,
max_complexity=15):
"""
Enhanced query validation and execution with additional controls
"""
response = {
'results': [],
'performance_metrics': None,
'warnings': [],
'optimization_suggestions': []
}
start_time = time.time()
conn = None
total_rows = 0
try:
# Validate and split queries
statements = validate_query(query)
# Check number of statements
if len(statements) > max_statements:
raise QueryLimitError(
f"Too many statements ({len(statements)}). Maximum allowed is {max_statements}"
)
# Connect to database
conn = connect_to_db(secret_name)
with conn.cursor() as cur:
# Set session parameters
cur.execute("SET TRANSACTION READ ONLY")
cur.execute("SET statement_timeout = '30s'")
cur.execute("SET idle_in_transaction_session_timeout = '60s'")
# Execute each statement
for stmt_index, stmt in enumerate(statements, 1):
# Analyze query complexity
complexity_metrics = analyze_query_complexity(stmt)
if complexity_metrics['complexity_score'] > max_complexity:
raise QueryComplexityError(
f"Statement {stmt_index} is too complex (score: {complexity_metrics['complexity_score']})"
)
# Add complexity warnings to response
response['warnings'].extend(
f"Statement {stmt_index}: {warning}"
for warning in complexity_metrics['warnings']
)
stmt_response = {
'columns': [],
'rows': [],
'truncated': False,
'message': '',
'row_count': 0,
'query': stmt,
'complexity_metrics': complexity_metrics
}
stmt_lower = stmt.lower().strip()
# Only add LIMIT for SELECT queries
if stmt_lower.startswith('select') and 'limit' not in stmt_lower:
remaining_rows = max_total_rows - total_rows
limit_rows = min(max_rows, remaining_rows)
stmt = f"{stmt} LIMIT {limit_rows + 1}"
# Execute with explain plan first for SELECT queries
if stmt_lower.startswith('select'):
#cur.execute(f"EXPLAIN (FORMAT JSON) {stmt}")
#explain_plan = cur.fetchone()[0]
# Analyze plan for potential issues
optimization_suggestions = analyze_query_performance(secret_name, stmt)
if optimization_suggestions:
response['optimization_suggestions'].extend(
f"Statement {stmt_index}: {suggestion}"
for suggestion in optimization_suggestions
)
# Execute actual query
cur.execute(stmt)
# Get column names
stmt_response['columns'] = [desc[0] for desc in cur.description]
# Fetch results
rows = cur.fetchall()
row_count = len(rows)
# Check total row limit
total_rows += row_count
if total_rows > max_total_rows:
stmt_response['truncated'] = True
excess_rows = total_rows - max_total_rows
rows = rows[:-excess_rows]
stmt_response['message'] = (
f"Results truncated. Maximum total rows ({max_total_rows}) reached"
)
total_rows = max_total_rows
# Check individual statement limit
elif row_count > max_rows:
stmt_response['truncated'] = True
rows = rows[:max_rows]
stmt_response['message'] = (
f"Results truncated to {max_rows} rows"
)
stmt_response['row_count'] = len(rows)
stmt_response['rows'] = [
dict(zip(stmt_response['columns'], row))
for row in rows
]
response['results'].append(stmt_response)
# Add overall performance metrics
total_time = time.time() - start_time
response['performance_metrics'] = {
'execution_time': total_time,
'statements_executed': len(statements),
'total_rows': total_rows,
'timestamp': datetime.utcnow().isoformat(),
'needs_analysis': total_time > 5,
'performance_message': (
f"Executed {len(statements)} statements in {total_time:.2f} seconds"
)
}
# Add performance recommendations if needed
if total_time > 5:
response['warnings'].append(
"Query execution time exceeded 5 seconds. Consider optimization."
)
return response
except (QueryComplexityError, QueryLimitError) as e:
error_msg = str(e)
logger.warning(error_msg)
raise ValueError(error_msg)
except psycopg2.Error as pe:
error_msg = f"Database error: {str(pe)}"
logger.error(error_msg)
raise Exception(error_msg)
except Exception as e:
error_msg = f"Unexpected error: {str(e)}"
logger.error(error_msg)
raise Exception(error_msg)
finally:
if conn:
conn.close()
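
# Usage sketch (hypothetical secret name; requires connectivity to the cluster):
#
#   results = validate_and_execute_queries(
#       "my/aurora/secret",
#       "SELECT * FROM orders; SELECT * FROM customers;",
#       max_rows=20, max_statements=5, max_total_rows=1000, max_complexity=15
#   )
#   print(format_enhanced_results(results))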
def get_secret(secret_name):
    """Retrieve and parse a JSON secret from AWS Secrets Manager"""
    region_name = os.environ['REGION']
session = boto3.session.Session()
client = session.client(
service_name='secretsmanager',
region_name=region_name
)
try:
secret_value = client.get_secret_value(SecretId=secret_name)
secret = json.loads(secret_value['SecretString'])
return secret
except ClientError as e:
raise Exception(f"Failed to get secret: {str(e)}")
def get_env_secret(environment):
    """Retrieve the secret name for the specified environment from Parameter Store"""
    ssm_client = boto3.client('ssm')
    if environment == 'prod':
        parameter_name = f'/AuroraOps/{environment}'
        try:
            # Get the secret name from Parameter Store
            response = ssm_client.get_parameter(Name=parameter_name)
            return response['Parameter']['Value']
        except ssm_client.exceptions.ParameterNotFound:
            error_message = f"Parameter not found: {parameter_name}"
            logger.error(error_message)
            raise Exception(error_message)
    elif environment == 'dev':
        try:
            # Get the secret name from Parameter Store
            response = ssm_client.get_parameter(
                Name=f'/AuroraOps/{environment}'
            )
            return response['Parameter']['Value']
        except Exception as e:
            raise Exception(f"Failed to get dev secret name from Parameter Store: {str(e)}")
    else:
        raise ValueError(f"Unknown environment: {environment}")
def connect_to_db(secret_name):
"""Establish database connection"""
    secret = get_secret(secret_name)
try:
conn = psycopg2.connect(
host=secret['host'],
database=secret['dbname'],
user=secret['username'],
password=secret['password'],
port=secret['port']
)
return conn
except Exception as e:
raise Exception(f"Failed to connect to the database: {str(e)}")
# Define the queries dictionary for different object types
queries = {
'table': """
WITH RECURSIVE columns AS (
SELECT
t.schemaname,
t.tablename,
array_to_string(
array_agg(
' ' || quote_ident(a.attname) || ' ' ||
pg_catalog.format_type(a.atttypid, a.atttypmod) ||
CASE WHEN a.attnotnull THEN ' NOT NULL' ELSE '' END ||
CASE WHEN ad.adbin IS NOT NULL
THEN ' DEFAULT ' || pg_get_expr(ad.adbin, ad.adrelid)
ELSE ''
END
ORDER BY a.attnum
),
E',\n'
) as column_definitions
FROM pg_catalog.pg_tables t
JOIN pg_catalog.pg_class c
ON c.relname = t.tablename
AND c.relnamespace = (
SELECT oid
FROM pg_catalog.pg_namespace
WHERE nspname = t.schemaname
)
JOIN pg_catalog.pg_attribute a
ON a.attrelid = c.oid
AND a.attnum > 0
AND NOT a.attisdropped
LEFT JOIN pg_catalog.pg_attrdef ad
ON ad.adrelid = c.oid
AND ad.adnum = a.attnum
WHERE t.schemaname NOT IN ('pg_catalog', 'information_schema')
GROUP BY t.schemaname, t.tablename
)
SELECT
col.schemaname || '.' || col.tablename as object_name,
'TABLE' as object_type,
CASE
WHEN col.column_definitions IS NOT NULL THEN
format(
'CREATE TABLE %I.%I (\n%s\n);',
col.schemaname,
col.tablename,
col.column_definitions
)
ELSE 'ERROR: No columns found for this table'
END as definition,
obj_description(
(col.schemaname || '.' || col.tablename)::regclass,
'pg_class'
) as description
FROM columns col
WHERE col.tablename ILIKE %s AND col.schemaname = %s
""",
'view': """
SELECT
schemaname || '.' || viewname as object_name,
'VIEW' as object_type,
format(
'CREATE OR REPLACE VIEW %I.%I AS\n%s',
schemaname,
viewname,
pg_get_viewdef(format('%I.%I', schemaname, viewname)::regclass, true)
) as definition,
obj_description(
(schemaname || '.' || viewname)::regclass,
'pg_class'
) as description
FROM pg_catalog.pg_views
WHERE viewname ILIKE %s
AND schemaname = %s
AND schemaname NOT IN ('pg_catalog', 'information_schema')
""",
'function': """
SELECT
n.nspname || '.' || p.proname as object_name,
'FUNCTION' as object_type,
pg_get_functiondef(p.oid) as definition,
obj_description(p.oid, 'pg_proc') as description,
p.prorettype::regtype as return_type,
p.provolatile,
p.proparallel
FROM pg_catalog.pg_proc p
JOIN pg_catalog.pg_namespace n ON n.oid = p.pronamespace
WHERE p.proname ILIKE %s
AND n.nspname = %s
AND n.nspname NOT IN ('pg_catalog', 'information_schema')
AND p.prokind = 'f' -- 'f' for function
""",
'procedure': """
SELECT
n.nspname || '.' || p.proname as object_name,
'PROCEDURE' as object_type,
pg_get_functiondef(p.oid) as definition,
obj_description(p.oid, 'pg_proc') as description
FROM pg_catalog.pg_proc p
JOIN pg_catalog.pg_namespace n ON n.oid = p.pronamespace
WHERE p.proname ILIKE %s
AND n.nspname = %s
AND n.nspname NOT IN ('pg_catalog', 'information_schema')
AND p.prokind = 'p' -- 'p' for procedure
""",
'trigger': """
SELECT
n.nspname || '.' || t.tgname as object_name,
'TRIGGER' as object_type,
pg_get_triggerdef(t.oid, true) as definition,
obj_description(t.oid, 'pg_trigger') as description
FROM pg_catalog.pg_trigger t
JOIN pg_catalog.pg_class c ON t.tgrelid = c.oid
JOIN pg_catalog.pg_namespace n ON c.relnamespace = n.oid
WHERE t.tgname ILIKE %s
AND n.nspname = %s
AND n.nspname NOT IN ('pg_catalog', 'information_schema')
AND NOT t.tgisinternal
""",
'sequence': """
SELECT
n.nspname || '.' || c.relname as object_name,
'SEQUENCE' as object_type,
format(
'CREATE SEQUENCE %I.%I\n INCREMENT %s\n MINVALUE %s\n MAXVALUE %s\n START %s\n CACHE %s%s;',
n.nspname,
c.relname,
s.seqincrement,
s.seqmin,
s.seqmax,
s.seqstart,
s.seqcache,
CASE WHEN s.seqcycle THEN '\n CYCLE' ELSE '' END
) as definition,
obj_description(c.oid, 'pg_class') as description
FROM pg_catalog.pg_sequence s
JOIN pg_catalog.pg_class c ON s.seqrelid = c.oid
JOIN pg_catalog.pg_namespace n ON c.relnamespace = n.oid
WHERE c.relname ILIKE %s
AND n.nspname = %s
AND n.nspname NOT IN ('pg_catalog', 'information_schema')
""",
'index': """
SELECT
n.nspname || '.' || c.relname as object_name,
'INDEX' as object_type,
pg_get_indexdef(i.indexrelid) as definition,
obj_description(i.indexrelid, 'pg_class') as description
FROM pg_catalog.pg_index i
JOIN pg_catalog.pg_class c ON i.indexrelid = c.oid
JOIN pg_catalog.pg_namespace n ON c.relnamespace = n.oid
WHERE c.relname ILIKE %s
AND n.nspname = %s
AND n.nspname NOT IN ('pg_catalog', 'information_schema')
"""
}
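
# Every entry in the queries dictionary takes two positional parameters, in
# order: the object name (matched with ILIKE, so SQL wildcards such as % are
# honored) and the schema name (matched exactly). For example, executing the
# 'table' query with ('order%', 'public') returns DDL for every public table
# whose name starts with "order".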
def extract_database_object_ddl(secret_name, object_type, object_name=None, object_schema=None):
"""
Extract DDL and description for database objects
Args:
secret_name (str): The name of the secret containing database credentials
object_type (str): Type of database object ('table', 'view', 'function', 'procedure', etc.)
        object_name (str): Name of the object to search for (required)
        object_schema (str): Schema name to filter objects (required)
Returns:
list: List of dictionaries containing object information
str: Error message if no objects found
"""
    conn = None
    try:
# Input validation
if not object_name or not object_schema:
raise ValueError("Both object_name and object_schema are required")
# Validate object_type
object_type_lower = object_type.lower()
if object_type_lower not in queries:
valid_types = ', '.join(queries.keys())
raise ValueError(f"Invalid object_type: {object_type}. Valid types are: {valid_types}")
# Connect to database
conn = connect_to_db(secret_name)
if not conn:
raise Exception("Failed to establish database connection")
results = []
with conn.cursor() as cur:
try:
                # Debug information
                logger.info(f"Executing {object_type_lower} lookup for {object_schema}.{object_name}")
                # Get query and execute
                query = queries[object_type_lower]
                params = [object_name, object_schema]
                logger.debug(f"Query: {query}; parameters: {params}")
cur.execute(query, params)
# Fetch results
rows = cur.fetchall()
if not rows:
print("No rows returned from query")
return "No matching objects found"
# Get column names
columns = [desc[0] for desc in cur.description]
print(f"\nColumns returned: {columns}")
print(f"Number of rows: {len(rows)}")
# Process results
for row in rows:
if not row:
continue
# Create result dictionary
result = {}
for i, column in enumerate(columns):
result[column] = row[i] if i < len(row) else None
# Add explanation based on object type
if result.get('definition'):
if object_type_lower == 'table':
result['explanation'] = analyze_table_definition(result['definition'])
elif object_type_lower == 'view':
result['explanation'] = analyze_view_definition(result['definition'])
elif object_type_lower in ('function', 'procedure'):
result['explanation'] = analyze_routine_definition(result['definition'])
elif object_type_lower == 'trigger':
result['explanation'] = analyze_trigger_definition(result['definition'])
else:
result['explanation'] = f"DDL for {object_type_lower}"
results.append(result)
print(f"\nProcessed {object_type_lower}: {result.get('object_name', 'unknown')}")
except Exception as e:
error_msg = f"Error executing query: {str(e)}"
print(f"\nError details:")
print(f"Query: {query}")
print(f"Parameters: {params}")
print(f"Error message: {error_msg}")
raise
# Return results
if not results:
return "No matching objects found"
print(f"\nSuccessfully retrieved {len(results)} objects")
return results
except Exception as e:
error_msg = f"Failed to extract database object DDL: {str(e)}"
print(f"\nError: {error_msg}")
raise
finally:
if conn:
            try:
                conn.close()
                logger.info("Database connection closed")
            except Exception as e:
                logger.error(f"Error closing connection: {str(e)}")
def analyze_table_definition(definition):
"""Analyze table DDL and return explanatory notes"""
explanation = ["This table contains the following structure:"]
# Extract column definitions
columns = []
for line in definition.split('\n'):
if 'CREATE TABLE' not in line and '(' not in line and ')' not in line:
if line.strip():
columns.append(line.strip())
# Analyze columns
for column in columns:
if column.endswith(','):
column = column[:-1]
parts = column.split()
if len(parts) >= 2:
col_name = parts[0]
col_type = parts[1]
constraints = ' '.join(parts[2:])
explanation.append(f"- {col_name}: {col_type} {constraints}")
return '\n'.join(explanation)
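
# Example (illustrative DDL string; output follows the parsing rules above):
#
#   analyze_table_definition(
#       "CREATE TABLE public.orders (\n  id integer NOT NULL,\n  total numeric\n);"
#   )
#   # -> "This table contains the following structure:
#   #     - id: integer NOT NULL
#   #     - total: numeric"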
def generate_object_explanation(obj_info):
"""
Generate a human-readable explanation of the database object
Args:
obj_info (dict): Dictionary containing object information
Returns:
str: Human-readable explanation of the object
"""
try:
definition = obj_info.get('definition', '')
obj_type = obj_info.get('object_type', '')
explanation = []
# Add object description if available
description = obj_info.get('description', '')
if description:
explanation.append(f"Description: {description}")
# Analyze based on object type
        # The analyze_* helpers return strings; use append so extend() does not
        # splice in individual characters
        if obj_type == 'TABLE':
            explanation.append(analyze_table_definition(definition))
        elif obj_type == 'VIEW':
            explanation.append(analyze_view_definition(definition))
        elif obj_type in ('FUNCTION', 'PROCEDURE'):
            explanation.append(analyze_routine_definition(definition))
return '\n'.join(explanation) if explanation else "No explanation available"
    except Exception as e:
        logger.error(f"Error generating explanation: {str(e)}")
        return "Error generating explanation"
def analyze_view_definition(definition):
"""Analyze view DDL and return explanatory notes"""
explanation = ["This view represents the following:"]
# Clean up the definition
clean_def = definition.replace('\n', ' ').strip()
# Extract main components
if 'SELECT' in clean_def:
explanation.append("This view performs a SELECT operation with the following characteristics:")
# Analyze SELECT clause
if 'JOIN' in clean_def:
explanation.append("- Joins multiple tables")
if 'WHERE' in clean_def:
explanation.append("- Applies filtering conditions")
if 'GROUP BY' in clean_def:
explanation.append("- Aggregates data")
if 'HAVING' in clean_def:
explanation.append("- Applies post-aggregation filters")
if 'ORDER BY' in clean_def:
explanation.append("- Sorts the results")
if 'UNION' in clean_def:
explanation.append("- Combines multiple result sets")
if 'WITH' in clean_def:
explanation.append("- Uses Common Table Expressions (CTEs)")
return '\n'.join(explanation)
def analyze_routine_definition(definition):
"""Analyze function/procedure DDL and return explanatory notes"""
explanation = []
# Determine if it's a function or procedure
if 'FUNCTION' in definition:
explanation.append("This is a function that:")
else:
explanation.append("This is a procedure that:")
# Extract parameters
if '(' in definition:
param_section = definition[definition.find('(')+1:definition.find(')')]
params = param_section.split(',')
if params and params[0].strip():
explanation.append("\nParameters:")
for param in params:
explanation.append(f"- {param.strip()}")
# Identify return type for functions
if 'RETURNS' in definition:
return_type = definition[definition.find('RETURNS')+7:].split()[0]
explanation.append(f"\nReturns: {return_type}")
# Analyze body
if 'BEGIN' in definition:
explanation.append("\nLogic overview:")
if 'IF' in definition:
explanation.append("- Contains conditional logic")
if 'LOOP' in definition or 'WHILE' in definition:
explanation.append("- Contains loops")
if 'INSERT' in definition:
explanation.append("- Performs data insertion")
if 'UPDATE' in definition:
explanation.append("- Performs data updates")
if 'DELETE' in definition:
explanation.append("- Performs data deletion")
if 'SELECT' in definition:
explanation.append("- Retrieves data")
if 'EXCEPTION' in definition:
explanation.append("- Includes error handling")
return '\n'.join(explanation)
def analyze_trigger_definition(definition):
"""Analyze trigger DDL and return explanatory notes"""
explanation = ["This trigger:"]
if 'BEFORE' in definition:
explanation.append("- Executes BEFORE the event")
elif 'AFTER' in definition:
explanation.append("- Executes AFTER the event")
if 'INSERT' in definition:
explanation.append("- Fires on INSERT")
if 'UPDATE' in definition:
explanation.append("- Fires on UPDATE")
if 'DELETE' in definition:
explanation.append("- Fires on DELETE")
if 'FOR EACH ROW' in definition:
explanation.append("- Executes for each affected row")
elif 'FOR EACH STATEMENT' in definition:
explanation.append("- Executes once per statement")
return '\n'.join(explanation)
def clean_query_for_explain(query):
"""
Remove any existing EXPLAIN or EXPLAIN ANALYZE keywords from the query
Parameters:
- query: Original query string
Returns:
- Cleaned query string
"""
# Remove common EXPLAIN variants (case-insensitive)
patterns = [
r'^\s*EXPLAIN\s+ANALYZE\s+',
r'^\s*EXPLAIN\s+\(.*?\)\s+',
r'^\s*EXPLAIN\s+'
]
cleaned_query = query
for pattern in patterns:
cleaned_query = re.sub(pattern, '', cleaned_query, flags=re.IGNORECASE)
return cleaned_query.strip()
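
# Examples (case-insensitive, matching the patterns above):
#   clean_query_for_explain("EXPLAIN ANALYZE SELECT 1")        # -> "SELECT 1"
#   clean_query_for_explain("explain (format json) SELECT 1")  # -> "SELECT 1"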
def analyze_query_performance(secret_name, query_or_object_name, parameters=None, object_type=None):
"""
Analyze query performance and provide optimization recommendations
Parameters:
- secret_name: Secret containing database credentials
- query_or_object_name: SQL query string or object name to analyze
- parameters: Optional. List of parameter values for parameterized queries
- object_type: Optional. If provided, will fetch definition from database object
"""
conn = connect_to_db(secret_name)
try:
with conn.cursor() as cur:
# If object_type is provided, fetch the query definition
if object_type:
query_to_analyze = get_object_definition(cur, query_or_object_name, object_type)
else:
query_to_analyze = query_or_object_name
# Clean the query before analysis
query_to_analyze = clean_query_for_explain(query_to_analyze)
# Check if the query contains parameter placeholders
has_parameters = any(f'${i}' in query_to_analyze for i in range(1, 21))
if has_parameters:
# Replace $n parameters with dummy placeholders
modified_query = query_to_analyze
param_count = 0
for i in range(1, 21):
if f'${i}' in modified_query:
param_count = max(param_count, i)
modified_query = modified_query.replace(f'${i}', 'NULL')
                # Use GENERIC_PLAN for parameterized queries (EXPLAIN's
                # GENERIC_PLAN option requires PostgreSQL 16 or later)
                cur.execute(f"EXPLAIN (GENERIC_PLAN, BUFFERS, FORMAT JSON) {modified_query}")
plan = cur.fetchone()[0]
cur.execute(f"EXPLAIN (FORMAT JSON) {modified_query}")
estimated_plan = cur.fetchone()[0]
# Pass True for is_generic_plan
analysis = analyze_execution_plan(plan[0], estimated_plan[0], True)
else:
# For non-parameterized queries, use ANALYZE
cur.execute(f"EXPLAIN (ANALYZE, BUFFERS, FORMAT JSON) {query_to_analyze}")
plan = cur.fetchone()[0]
cur.execute(f"EXPLAIN (FORMAT JSON) {query_to_analyze}")
estimated_plan = cur.fetchone()[0]
# Pass False for is_generic_plan
analysis = analyze_execution_plan(plan[0], estimated_plan[0], False)
return analysis
except Exception as e:
raise Exception(f"Failed to analyze query performance: {str(e)}")
finally:
if conn:
conn.close()
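
# Usage sketch (hypothetical secret; the parameterized branch relies on
# EXPLAIN (GENERIC_PLAN), available from PostgreSQL 16):
#
#   analysis = analyze_query_performance(
#       "my/aurora/secret",
#       "SELECT * FROM orders WHERE customer_id = $1"
#   )
#   print(format_analysis_output(analysis))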
def analyze_execution_plan(actual_plan, estimated_plan, is_generic_plan):
"""
Analyze execution plan and provide detailed explanations and recommendations
"""
analysis = {
'summary': [],
'issues': [],
'recommendations': [],
'performance_stats': {}
}
analysis['plan_type'] = 'Generic Plan' if is_generic_plan else 'Analyzed Plan'
# Extract key performance metrics
total_cost = actual_plan['Plan'].get('Total Cost')
rows = actual_plan['Plan'].get('Plan Rows') # Use Plan Rows for generic plans
estimated_rows = estimated_plan['Plan'].get('Plan Rows')
# Initialize performance stats
performance_stats = {
'total_cost': total_cost,
'estimated_rows': estimated_rows,
'plan_rows': rows
}
# Add actual execution metrics only for analyzed plans
if not is_generic_plan:
actual_time = actual_plan['Plan'].get('Actual Total Time')
actual_rows = actual_plan['Plan'].get('Actual Rows')
performance_stats.update({
'execution_time_ms': actual_time,
'actual_rows': actual_rows
})
analysis['performance_stats'] = performance_stats
# Analyze node types and operations
analyze_plan_node(actual_plan['Plan'], analysis, is_generic_plan)
# Check for common issues
identify_performance_issues(actual_plan['Plan'], estimated_plan['Plan'], analysis, is_generic_plan)
# Generate recommendations
generate_recommendations(analysis)
return analysis
def analyze_plan_node(node, analysis, is_generic_plan):
"""
Recursively analyze each node in the execution plan
"""
# Analyze current node
node_type = node['Node Type']
# Check for expensive operations with appropriate metrics based on plan type
if node_type == 'Seq Scan':
analysis['issues'].append({
'type': 'sequential_scan',
'description': f"Sequential scan detected on table {node.get('Relation Name')}",
'severity': 'high'
})
elif node_type == 'Nested Loop':
if is_generic_plan:
if node.get('Plan Rows', 0) > 1000:
analysis['issues'].append({
'type': 'nested_loop_large_dataset',
'description': "Nested loop join planned for large dataset",
'severity': 'medium'
})
else:
if node.get('Actual Rows', 0) > 1000:
analysis['issues'].append({
'type': 'nested_loop_large_dataset',
'description': "Nested loop join performed on large dataset",
'severity': 'medium'
})
elif node_type == 'Hash Join':
rows_metric = node.get('Plan Rows' if is_generic_plan else 'Actual Rows', 0)
if node.get('Hash Cond') and rows_metric > 10000:
analysis['issues'].append({
'type': 'large_hash_join',
'description': "Large hash join operation detected",
'severity': 'medium'
})
# Check for filter conditions
if 'Filter' in node:
analyze_filter_condition(node['Filter'], analysis)
# Recursively analyze child nodes
for child in node.get('Plans', []):
analyze_plan_node(child, analysis, is_generic_plan)
def analyze_filter_condition(filter_condition, analysis):
"""
Analyze filter conditions for potential optimization opportunities
"""
filter_lower = filter_condition.lower()
    # Check for function calls in the filter (an identifier immediately
    # followed by an opening parenthesis)
    if re.search(r'\b[a-z_][a-z0-9_]*\s*\(', filter_lower):
analysis['issues'].append({
'type': 'function_in_filter',
'description': "Function call in WHERE clause may prevent index usage",
'severity': 'medium'
})
    # Check for LIKE patterns with a leading wildcard (EXPLAIN may render
    # LIKE as the ~~ operator in filter conditions)
    if re.search(r"(like|~~)\s+'%", filter_lower):
analysis['issues'].append({
'type': 'leading_wildcard',
'description': "Leading wildcard in LIKE clause prevents index usage",
'severity': 'medium'
})
def identify_performance_issues(actual_node, estimated_node, analysis, is_generic_plan):
"""
Identify performance issues by comparing actual vs estimated plans
"""
if not is_generic_plan:
# Row estimation analysis only for actual execution plans
if estimated_node.get('Plan Rows', 0) > 0:
estimation_ratio = actual_node.get('Actual Rows', 0) / estimated_node['Plan Rows']
if estimation_ratio > 10 or estimation_ratio < 0.1:
analysis['issues'].append({
'type': 'poor_statistics',
'description': f"Statistics may be outdated - row estimation is off by factor of {estimation_ratio:.1f}",
'severity': 'high'
})
# Parallel execution analysis (applies to both plan types)
if actual_node.get('Workers Planned', 0) > 0 and actual_node.get('Workers Launched', 0) == 0:
analysis['issues'].append({
'type': 'parallel_execution_failed',
'description': "Parallel execution was planned but not executed",
'severity': 'medium'
})
def generate_recommendations(analysis):
"""
Generate specific recommendations based on identified issues
"""
for issue in analysis['issues']:
if issue['type'] == 'sequential_scan':
analysis['recommendations'].append({
'issue': 'Sequential Scan Detected',
'recommendation': """
Consider the following solutions:
1. Create an index on the commonly queried columns
2. Review WHERE clause conditions for index compatibility
3. Ensure statistics are up to date with ANALYZE
Example index creation:
CREATE INDEX idx_name ON table_name (column_name);
"""
})
elif issue['type'] == 'poor_statistics':
analysis['recommendations'].append({
'issue': 'Statistics Mismatch',
'recommendation': """
Update statistics for more accurate query planning:
1. Run ANALYZE on the affected tables
2. Consider increasing statistics target:
ALTER TABLE table_name ALTER COLUMN column_name SET STATISTICS 1000;
3. Review and possibly update auto_vacuum settings
"""
})
elif issue['type'] == 'function_in_filter':
analysis['recommendations'].append({
'issue': 'Function in WHERE Clause',
'recommendation': """
Optimize filter conditions:
1. Remove function calls from WHERE clause
2. Consider creating a computed column with an index
3. Rewrite the condition to use direct column comparisons
Example:
            Instead of: WHERE UPPER(column) = 'VALUE'
            Use an expression index so the function call can be indexed:
            CREATE INDEX idx_upper ON table_name (UPPER(column));
"""
})
def format_analysis_output(analysis):
"""
Format the analysis results into human-readable text
"""
output = []
# Performance Statistics
output.append("Query Performance Summary:")
# Check if this is a generic plan
is_generic_plan = analysis.get('plan_type') == 'Generic Plan'
if is_generic_plan:
# For generic plans, show estimated metrics
output.append(f"- Plan Type: Generic Plan (Parameterized Query)")
output.append(f"- Estimated Total Cost: {analysis['performance_stats'].get('total_cost', 'N/A')}")
output.append(f"- Estimated Rows: {analysis['performance_stats'].get('estimated_rows', 'N/A')}")
output.append(f"- Plan Rows: {analysis['performance_stats'].get('plan_rows', 'N/A')}")
else:
# For actual execution plans, show actual metrics
output.append(f"- Plan Type: Analyzed Plan")
output.append(f"- Execution Time: {analysis['performance_stats'].get('execution_time_ms', 'N/A'):.2f} ms")
output.append(f"- Actual Rows: {analysis['performance_stats'].get('actual_rows', 'N/A')}")
output.append(f"- Estimated Rows: {analysis['performance_stats'].get('estimated_rows', 'N/A')}")
output.append("")
# Issues
if analysis['issues']:
output.append("Identified Issues:")
for issue in analysis['issues']:
output.append(f"- {issue['description']} (Severity: {issue['severity']})")
output.append("")
# Recommendations
if analysis['recommendations']:
output.append("Recommendations:")
for rec in analysis['recommendations']:
output.append(f"Problem: {rec['issue']}")
output.append(f"Solution: {rec['recommendation']}")
output.append("")
return "\n".join(output)
def monitor_query_performance(query, start_time, rows_returned):
"""
Monitor query performance and suggest analysis if needed
Args:
query (str): Executed SQL query
start_time (float): Query start timestamp
rows_returned (int): Number of rows returned
Returns:
dict: Performance metrics and analysis suggestion
"""
execution_time = time.time() - start_time
metrics = {
'query': query,
'execution_time': execution_time,
'rows_returned': rows_returned,
'timestamp': datetime.utcnow().isoformat(),
'needs_analysis': False,
'performance_message': ''
}
# Define performance thresholds
SLOW_QUERY_THRESHOLD = 5 # seconds
HIGH_ROWS_THRESHOLD = 10000
# Check for performance issues
performance_issues = []
if execution_time > SLOW_QUERY_THRESHOLD:
performance_issues.append(f"Query took {execution_time:.2f} seconds to execute")
metrics['needs_analysis'] = True
if rows_returned > HIGH_ROWS_THRESHOLD:
performance_issues.append(f"Query returned {rows_returned} rows")
metrics['needs_analysis'] = True
if metrics['needs_analysis']:
metrics['performance_message'] = (
"⚠️ Performance Warning:\n"
f"{'; '.join(performance_issues)}.\n"
"Would you like me to analyze this query for potential optimizations? "
"Reply with 'yes' to get performance recommendations."
)
logger.warning(f"Slow query detected: {query}")
else:
metrics['performance_message'] = f"Query executed successfully in {execution_time:.2f} seconds"
return metrics
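
# Usage sketch:
#
#   start = time.time()
#   # ... execute a query and count its rows ...
#   metrics = monitor_query_performance("SELECT ...", start, rows_returned=42)
#   if metrics['needs_analysis']:
#       print(metrics['performance_message'])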
def validate_query(query):
"""
Validate query for security concerns and split into statements
Args:
query (str): SQL query to validate
Returns:
list: List of validated statements
Raises:
ValueError: If query contains prohibited operations
"""
if not query or not isinstance(query, str):
raise ValueError("Query must be a non-empty string")
def is_within_quotes(text, position):
"""Check if a position in text is within quotes"""
single_quotes = False
double_quotes = False
for i in range(position):
if text[i] == "'" and not double_quotes:
single_quotes = not single_quotes
elif text[i] == '"' and not single_quotes:
double_quotes = not double_quotes
return single_quotes or double_quotes
def split_statements(query_text):
"""Split query into individual statements, respecting quotes and comments"""
statements = []
current_stmt = []
i = 0
comment_block = False
line_comment = False
while i < len(query_text):
char = query_text[i]
            # Handle comment blocks (append both marker characters; the extra
            # i += 1 here plus the one at the end of the loop skips the pair)
            if query_text[i:i+2] == '/*' and not line_comment:
                comment_block = True
                current_stmt.append('/*')
                i += 1
            elif query_text[i:i+2] == '*/' and comment_block:
                comment_block = False
                current_stmt.append('*/')
                i += 1
            # Handle line comments
            elif query_text[i:i+2] == '--' and not comment_block:
                line_comment = True
                current_stmt.append('--')
                i += 1
elif char == '\n' and line_comment:
line_comment = False
current_stmt.append(char)
# Handle semicolons
elif char == ';' and not comment_block and not line_comment and not is_within_quotes(query_text, i):
current_stmt.append(char)
stmt = ''.join(current_stmt).strip()
if stmt:
statements.append(stmt)
current_stmt = []
else:
current_stmt.append(char)
i += 1
# Add the last statement if exists
last_stmt = ''.join(current_stmt).strip()
if last_stmt:
statements.append(last_stmt)
return [stmt for stmt in statements if stmt]
# Split into statements
statements = split_statements(query)
validated_statements = []
# Validate each statement
for stmt in statements:
stmt = stmt.strip()
if stmt.endswith(';'):
stmt = stmt[:-1]
stmt_lower = stmt.lower().strip()
# Get the command type
first_word = stmt_lower.split()[0] if stmt_lower.split() else ''
if first_word not in ['select', 'show']:
raise ValueError(f"Prohibited operation detected: {first_word}")
# For SELECT statements, check for dangerous operations
if first_word == 'select':
dangerous_operations = [
r'\binsert\b', r'\bupdate\b', r'\bdelete\b', r'\bdrop\b',
r'\btruncate\b', r'\balter\b', r'\bcreate\b', r'\bgrant\b',
r'\brevoke\b', r'\bexecute\b', r'\bcopy\b'
]
# Remove content within quotes for checking
query_for_check = ''
in_quote = False
quote_char = None
for char in stmt:
if char in ["'", '"'] and (not quote_char or char == quote_char):
if not in_quote:
quote_char = char
in_quote = True
else:
quote_char = None
in_quote = False
elif not in_quote:
query_for_check += char
# Check for dangerous operations
for operation in dangerous_operations:
if re.search(operation, query_for_check.lower()):
raise ValueError(f"Statement contains prohibited operation: {operation}")
validated_statements.append(stmt)
return validated_statements
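
# Behavior examples (only statements beginning with SELECT or SHOW pass):
#
#   validate_query("SELECT 1; SELECT 2;")  # -> ["SELECT 1", "SELECT 2"]
#   validate_query("DELETE FROM orders")   # raises ValueError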
def execute_read_query(secret_name, query, max_rows=20):
"""
Execute read-only queries safely and return results with monitoring
Args:
secret_name (str): Secret containing database credentials
query (str): SQL query to execute
max_rows (int): Maximum number of rows to return (only for SELECT queries)
Returns:
dict: Query results and metadata
"""
response = {
'results': [],
'performance_metrics': None,
'warnings': [],
'optimization_suggestions': []
}
start_time = time.time()
conn = None
try:
# Validate and split queries
statements = validate_query(query)
# Connect to database
conn = connect_to_db(secret_name)
with conn.cursor() as cur:
# Set session to read-only and timeout
cur.execute("SET TRANSACTION READ ONLY")
cur.execute("SET statement_timeout = '30s'")
# Execute each statement
for stmt_index, stmt in enumerate(statements, 1):
stmt_response = {
'columns': [],
'rows': [],
'truncated': False,
'message': '',
'row_count': 0,
'query': stmt
}
# Determine if it's a SELECT query
stmt_lower = stmt.lower().strip()
is_select_query = stmt_lower.lstrip('(').startswith('select')
# Prepare the final query
final_query = stmt
if is_select_query and 'limit' not in stmt_lower:
final_query = f"{stmt} LIMIT {max_rows + 1}"
# Execute query
try:
cur.execute(final_query)
except psycopg2.Error as pe:
logger.error(f"Error executing query: {final_query}")
logger.error(f"Error details: {str(pe)}")
raise
# Get column names
stmt_response['columns'] = [desc[0] for desc in cur.description]
# Fetch results
rows = cur.fetchall()
total_rows = len(rows)
# Handle row limiting only for SELECT queries
if is_select_query and total_rows > max_rows:
stmt_response['truncated'] = True
rows = rows[:max_rows]
stmt_response['message'] = (
f"Results truncated to {max_rows} rows for performance reasons. "
f"Total rows available: {total_rows}"
)
stmt_response['row_count'] = max_rows
else:
stmt_response['row_count'] = total_rows
# Convert rows to list of dictionaries
stmt_response['rows'] = [
dict(zip(stmt_response['columns'], row))
for row in rows
]
# Add performance monitoring only for SELECT queries
if is_select_query:
complexity_metrics = analyze_query_complexity(stmt)
stmt_response['complexity_metrics'] = complexity_metrics
# Add complexity warnings if any
if complexity_metrics['warnings']:
response['warnings'].extend(
f"Statement {stmt_index}: {warning}"
for warning in complexity_metrics['warnings']
)
# Store the original query in the response
stmt_response['query'] = f"{stmt};"
response['results'].append(stmt_response)
# Add overall performance metrics
total_time = time.time() - start_time
response['performance_metrics'] = {
'execution_time': total_time,
'statements_executed': len(statements),
'timestamp': datetime.utcnow().isoformat(),
'needs_analysis': total_time > 5,
'performance_message': (
f"Executed {len(statements)} statements in {total_time:.2f} seconds"
)
}
return response
except ValueError as ve:
error_msg = f"Query validation failed: {str(ve)}"
logger.error(error_msg)
raise ValueError(error_msg)
except psycopg2.Error as pe:
error_msg = f"Database error: {str(pe)}"
logger.error(error_msg)
raise Exception(error_msg)
except Exception as e:
error_msg = f"Unexpected error: {str(e)}"
logger.error(error_msg)
raise Exception(error_msg)
finally:
if conn:
conn.close()
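
# Usage sketch (hypothetical secret name):
#
#   results = execute_read_query("my/aurora/secret", "SELECT * FROM orders", max_rows=20)
#   print(format_multi_query_results(results))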
def format_enhanced_results(results):
"""
Format results with enhanced information
"""
formatted_output = []
# Add performance summary
metrics = results['performance_metrics']
formatted_output.append("Query Execution Summary:")
formatted_output.append(f"- Total execution time: {metrics['execution_time']:.2f} seconds")
formatted_output.append(f"- Statements executed: {metrics['statements_executed']}")
formatted_output.append(f"- Total rows returned: {metrics['total_rows']}")
formatted_output.append("")
# Add warnings if any
if results['warnings']:
formatted_output.append("Warnings:")
for warning in results['warnings']:
formatted_output.append(f"- {warning}")
formatted_output.append("")
# Add optimization suggestions if any
if results['optimization_suggestions']:
formatted_output.append("Optimization Suggestions:")
for suggestion in results['optimization_suggestions']:
formatted_output.append(f"- {suggestion}")
formatted_output.append("")
# Format each statement's results
for i, result in enumerate(results['results'], 1):
formatted_output.append(f"Statement {i}:")
formatted_output.append(f"Query: {result['query']}")
        # Add complexity metrics when present (SELECT statements only)
        complexity = result.get('complexity_metrics')
        if complexity:
            formatted_output.append("Complexity Analysis:")
            formatted_output.append(f"- Score: {complexity['complexity_score']}")
            formatted_output.append(f"- Joins: {complexity['join_count']}")
            formatted_output.append(f"- Subqueries: {complexity['subquery_count']}")
            formatted_output.append(f"- Aggregations: {complexity['aggregation_count']}")
if result['message']:
formatted_output.append(f"Note: {result['message']}")
if result['columns']:
# Calculate column widths
            widths = {
                col: max(
                    len(str(col)),
                    max((len(str(row[col])) for row in result['rows']), default=0)
                )
                for col in result['columns']
            }
# Add header
header = " | ".join(
str(col).ljust(widths[col])
for col in result['columns']
)
formatted_output.append(header)
formatted_output.append("-" * len(header))
# Add rows
for row in result['rows']:
formatted_output.append(" | ".join(
str(row[col]).ljust(widths[col])
for col in result['columns']
))
formatted_output.append(f"Rows returned: {result['row_count']}")
formatted_output.append("")
return "\n".join(formatted_output)
def format_query_results(results):
"""
Format query results for display
Args:
results (dict): Query execution results
Returns:
str: Formatted results string
"""
formatted_output = []
# Add performance message first
if results['performance_metrics'] and results['performance_metrics']['performance_message']:
formatted_output.append(results['performance_metrics']['performance_message'] + "\n")
# Add truncation message if applicable
if results['message']:
formatted_output.append(f"Note: {results['message']}\n")
# Add column headers
if results['columns']:
# Calculate column widths
widths = {}
for col in results['columns']:
widths[col] = len(str(col))
for row in results['rows']:
widths[col] = max(widths[col], len(str(row[col])))
# Create header
header = " | ".join(
str(col).ljust(widths[col])
for col in results['columns']
)
formatted_output.append(header)
# Add separator
separator = "-" * len(header)
formatted_output.append(separator)
# Add rows
for row in results['rows']:
formatted_row = " | ".join(
str(row[col]).ljust(widths[col])
for col in results['columns']
)
formatted_output.append(formatted_row)
# Add summary
formatted_output.append(f"\nTotal rows: {results['row_count']}")
return "\n".join(formatted_output)
def format_multi_query_results(results):
"""Format results from multiple statements"""
formatted_output = []
# Add performance summary
metrics = results['performance_metrics']
formatted_output.append(f"Query Execution Summary:")
formatted_output.append(f"- Total execution time: {metrics['execution_time']:.2f} seconds")
formatted_output.append(f"- Statements executed: {metrics['statements_executed']}")
formatted_output.append("")
# Format each statement's results
for i, result in enumerate(results['results'], 1):
formatted_output.append(f"Statement {i}: {result['query']}")
if result['message']:
formatted_output.append(f"Note: {result['message']}")
if result['columns']:
# Calculate column widths
            widths = {
                col: max(
                    len(str(col)),
                    max((len(str(row[col])) for row in result['rows']), default=0)
                )
                for col in result['columns']
            }
# Add header
header = " | ".join(
str(col).ljust(widths[col])
for col in result['columns']
)
formatted_output.append(header)
formatted_output.append("-" * len(header))
# Add rows
for row in result['rows']:
formatted_output.append(" | ".join(
str(row[col]).ljust(widths[col])
for col in result['columns']
))
formatted_output.append(f"Rows returned: {result['row_count']}")
formatted_output.append("")
return "\n".join(formatted_output)
def lambda_handler(event, context):
try:
environment = event['environment']
action_type = event['action_type']
secret_name = get_env_secret(environment)
# Get explain plan for a query
if action_type == 'explain_query':
query = event['query']
print("Executing explain query scripts")
results = analyze_query_performance(secret_name, query)
formatted_results = format_analysis_output(results)
elif action_type == 'extract_ddl':
object_type = event['object_type']
object_name = event['object_name']
object_schema = event['object_schema']
print("Generating the DDL scripts for the object")
results = extract_database_object_ddl(secret_name, object_type=object_type, object_name=object_name, object_schema=object_schema)
# Convert results to string if it's not already
formatted_results = str(results) if results else "No results found"
elif action_type == 'execute_query':
query = event['query']
print("Executing read-only queries")
results = validate_and_execute_queries(
secret_name,
query,
max_rows=20,
max_statements=5,
max_total_rows=1000,
max_complexity=15
)
formatted_results = format_enhanced_results(results)
        else:
            print(f"Unknown action_type received: {action_type}")
            return {
                "functionResponse": {
                    "content": f"Error: Unknown action_type {action_type}"
                }
            }
# Format the response properly
response_body = {
'TEXT': {
'body': formatted_results
}
}
function_response = {
'functionResponse': {
'responseBody': response_body
}
}
return function_response
except Exception as e:
print(f"Error in lambda_handler: {str(e)}") # Add debugging
return {
"functionResponse": {
"content": f"Error inside the exception block: {str(e)}"
}
}
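
# Local smoke test (hypothetical event; needs AWS credentials, the REGION
# environment variable, the /AuroraOps/dev parameter, and network access to
# the cluster):
if __name__ == "__main__":
    sample_event = {
        "environment": "dev",
        "action_type": "execute_query",
        "query": "SELECT now();"
    }
    print(lambda_handler(sample_event, None))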