# import ast import sys from typing import Any, Optional import streamlit as st import requests import base64 import hashlib import os import uuid from urllib.parse import urlencode # from streamlit_cookies_manager import EncryptedCookieManager import json import jwt import time import re import urllib from scripts.utils import read_config, get_aws_region, get_ssm_parameter from streamlit_cookies_controller import CookieController # ==== Configuration ==== AGENT_NAME = "default" # crude way to parse args if len(sys.argv) > 1: for arg in sys.argv: if arg.startswith("--agent="): AGENT_NAME = arg.split("=")[1] COGNITO_DOMAIN = get_ssm_parameter( "/app/customersupport/agentcore/cognito_domain" ).replace("https://", "") CLIENT_ID = get_ssm_parameter("/app/customersupport/agentcore/web_client_id") REDIRECT_URI = "http://localhost:8501/" SCOPES = "email openid profile" # ==== Initialize cookies manager ==== cookies = CookieController() st.set_page_config(layout="wide") # if not cookies.ready(): # st.stop() # Wait for cookies to load # ==== PKCE Helpers ==== def generate_pkce_pair(): code_verifier = base64.urlsafe_b64encode(os.urandom(40)).decode("utf-8").rstrip("=") code_challenge = ( base64.urlsafe_b64encode(hashlib.sha256(code_verifier.encode()).digest()) .decode("utf-8") .rstrip("=") ) return code_verifier, code_challenge # ==== Clickable URL Helpers ==== def make_urls_clickable(text): """Convert URLs in text to clickable HTML links.""" # Comprehensive URL regex pattern url_pattern = r"https?://(?:[-\w.])+(?:\:[0-9]+)?(?:/(?:[\w/_.])*(?:\?(?:[\w&=%.])*)?(?:\#(?:[\w.])*)?)?" def replace_url(match): url = match.group(0) # Clean URL and create clickable link with styling to match theme return f'{url}' return re.sub(url_pattern, replace_url, text) def create_safe_markdown_text(text, message_placeholder): safe_text = text.encode("utf-16", "surrogatepass").decode("utf-16") message_placeholder.markdown(safe_text, unsafe_allow_html=True) # ==== Logout function ==== def logout(): cookies.remove("tokens") # Clear cookies on logout as well (in case) # cookies.remove("code_verifier") # cookies.remove("code_challenge") # cookies.remove("oauth_state") # cookies.save() del st.session_state["session_id"] del st.session_state["messages"] del st.session_state["agent_arn"] del st.session_state["pending_assistant"] del st.session_state["region"] logout_url = f"https://{COGNITO_DOMAIN}/logout?" + urlencode( {"client_id": CLIENT_ID, "logout_uri": REDIRECT_URI} ) create_safe_markdown_text( f'', st ) st.rerun() # ==== Styles ==== st.markdown( """ """, unsafe_allow_html=True, ) # ==== Handle OAuth callback ==== query_params = st.query_params if query_params.get("code") and query_params.get("state") and not cookies.get("tokens"): auth_code = query_params.get("code") returned_state = query_params.get("state") code_verifier = cookies.get("code_verifier") state = cookies.get("oauth_state") print(f"Check state {cookies.get('oauth_state')} against {returned_state}") if not state: st.stop() else: if returned_state != state: st.error("State mismatch - potential CSRF detected") st.stop() # Exchange authorization code for tokens token_url = f"https://{COGNITO_DOMAIN}/oauth2/token" data = { "grant_type": "authorization_code", "client_id": CLIENT_ID, "code": auth_code, "redirect_uri": REDIRECT_URI, "code_verifier": code_verifier, } headers = {"Content-Type": "application/x-www-form-urlencoded"} response = requests.post(token_url, data=data, headers=headers) if response.ok: tokens = response.json() # st.success("Logged in successfully!") # Clear the cookies after login to avoid reuse of old code_verifier and state cookies.set("tokens", json.dumps(tokens)) cookies.remove("code_verifier") cookies.remove("code_challenge") cookies.remove("oauth_state") # cookies.save() st.query_params.clear() # st.rerun() else: st.error(f"Failed to exchange token: {response.status_code} - {response.text}") # ==== Sidebar with welcome, tokens, and logout ==== st.sidebar.title("Access Tokens") def invoke_endpoint( agent_arn: str, payload, session_id: str, bearer_token: Optional[str], # noqa: F821 endpoint_name: str = "DEFAULT", ) -> Any: """Invoke agent endpoint using HTTP request with bearer token. Args: agent_arn: Agent ARN to invoke payload: Payload to send (dict or string) session_id: Session ID for the request bearer_token: Bearer token for authentication endpoint_name: Endpoint name, defaults to "DEFAULT" Returns: Response from the agent endpoint """ # Escape agent ARN for URL escaped_arn = urllib.parse.quote(agent_arn, safe="") # Build URL url = f"https://bedrock-agentcore.{st.session_state['region']}.amazonaws.com/runtimes/{escaped_arn}/invocations" # Headers headers = { "Authorization": f"Bearer {bearer_token}", "Content-Type": "application/json", "X-Amzn-Bedrock-AgentCore-Runtime-Session-Id": session_id, } # Parse the payload string back to JSON object to send properly # This ensures consistent payload structure between boto3 and HTTP clients try: body = json.loads(payload) if isinstance(payload, str) else payload except json.JSONDecodeError: # Fallback for non-JSON strings - wrap in payload object body = {"payload": payload} try: # Make request with timeout response = requests.post( url, params={"qualifier": endpoint_name}, headers=headers, json=body, timeout=100, stream=True, ) last_data = False for line in response.iter_lines(chunk_size=1): if line: line = line.decode("utf-8") if line.startswith("data: "): last_data = True line = line[6:] yield line elif line: if last_data: yield "\n" + line last_data = False except requests.exceptions.RequestException as e: print("Failed to invoke agent endpoint: %s", str(e)) raise # ==== Main app ==== if cookies.get("tokens"): st.sidebar.code(cookies.get("tokens")) if st.sidebar.button("Logout"): logout() if "session_id" not in st.session_state: st.session_state["session_id"] = str(uuid.uuid4()) if "agent_arn" not in st.session_state: runtime_config = read_config(".bedrock_agentcore.yaml") st.session_state["agent_arn"] = runtime_config["agents"][AGENT_NAME][ "bedrock_agentcore" ]["agent_arn"] if "region" not in st.session_state: st.session_state["region"] = get_aws_region() st.sidebar.write("Agent Arn") st.sidebar.code(st.session_state["agent_arn"]) st.sidebar.write("Session Id") st.sidebar.code(st.session_state["session_id"]) token = json.loads(cookies.get("tokens")) claims = jwt.decode(token["id_token"], options={"verify_signature": False}) st.title("Customer Support Assistant") st.markdown( """