Eashan Kaushik 88e19eddc9
customer-support-assistant v1 (#103)
* feat(customer-support): updated code

* Delete 02-use-cases/customer-support-assistant/Dockerfile

Signed-off-by: Eashan Kaushik <50113394+EashanKaushik@users.noreply.github.com>

* Update .gitignore

Signed-off-by: Eashan Kaushik <50113394+EashanKaushik@users.noreply.github.com>

---------

Signed-off-by: Eashan Kaushik <50113394+EashanKaushik@users.noreply.github.com>
2025-07-21 11:34:00 -04:00

147 lines
4.9 KiB
Python

import base64
import hashlib
import os
import uuid
import json
import jwt
from urllib.parse import urlencode
import requests
import streamlit as st
from streamlit_cookies_controller import CookieController
from scripts.utils import get_ssm_parameter
class AuthManager:
def __init__(self):
self.cognito_domain = get_ssm_parameter(
"/app/customersupport/agentcore/cognito_domain"
).replace("https://", "")
self.client_id = get_ssm_parameter(
"/app/customersupport/agentcore/web_client_id"
)
self.redirect_uri = "http://localhost:8501/"
self.scopes = "email openid profile"
self.cookies = CookieController()
def generate_pkce_pair(self):
code_verifier = (
base64.urlsafe_b64encode(os.urandom(40)).decode("utf-8").rstrip("=")
)
code_challenge = (
base64.urlsafe_b64encode(hashlib.sha256(code_verifier.encode()).digest())
.decode("utf-8")
.rstrip("=")
)
return code_verifier, code_challenge
def logout(self):
self.cookies.remove("tokens")
# Clear session state
if "session_id" in st.session_state:
del st.session_state["session_id"]
if "messages" in st.session_state:
del st.session_state["messages"]
if "agent_arn" in st.session_state:
del st.session_state["agent_arn"]
if "pending_assistant" in st.session_state:
del st.session_state["pending_assistant"]
if "region" in st.session_state:
del st.session_state["region"]
logout_url = f"https://{self.cognito_domain}/logout?" + urlencode(
{"client_id": self.client_id, "logout_uri": self.redirect_uri}
)
st.markdown(
f'<meta http-equiv="refresh" content="0;url={logout_url}">',
unsafe_allow_html=True,
)
st.rerun()
def handle_oauth_callback(self):
query_params = st.query_params
if (
query_params.get("code")
and query_params.get("state")
and not self.cookies.get("tokens")
):
auth_code = query_params.get("code")
returned_state = query_params.get("state")
code_verifier = self.cookies.get("code_verifier")
state = self.cookies.get("oauth_state")
if not state:
st.stop()
if returned_state != state:
st.error("State mismatch - potential CSRF detected")
st.stop()
# Exchange authorization code for tokens
token_url = f"https://{self.cognito_domain}/oauth2/token"
data = {
"grant_type": "authorization_code",
"client_id": self.client_id,
"code": auth_code,
"redirect_uri": self.redirect_uri,
"code_verifier": code_verifier,
}
headers = {"Content-Type": "application/x-www-form-urlencoded"}
response = requests.post(token_url, data=data, headers=headers)
if response.ok:
tokens = response.json()
self.cookies.set("tokens", json.dumps(tokens))
self.cookies.remove("code_verifier")
self.cookies.remove("code_challenge")
self.cookies.remove("oauth_state")
st.query_params.clear()
else:
st.error(
f"Failed to exchange token: {response.status_code} - {response.text}"
)
def get_login_url(self):
code_verifier, code_challenge = self.generate_pkce_pair()
self.cookies.set("code_verifier", code_verifier)
self.cookies.set("code_challenge", code_challenge)
state = str(uuid.uuid4())
self.cookies.set("oauth_state", state)
login_params = {
"response_type": "code",
"client_id": self.client_id,
"redirect_uri": self.redirect_uri,
"scope": self.scopes,
"code_challenge_method": "S256",
"code_challenge": self.cookies.get("code_challenge"),
"state": self.cookies.get("oauth_state"),
}
return (
f"https://{self.cognito_domain}/oauth2/authorize?{urlencode(login_params)}"
)
def is_authenticated(self):
return bool(self.cookies.get("tokens"))
def get_tokens(self):
tokens_data = self.cookies.get("tokens")
if not tokens_data:
return None
# Handle both string and dict cases
if isinstance(tokens_data, str):
return json.loads(tokens_data)
elif isinstance(tokens_data, dict):
return tokens_data
else:
return None
def get_user_claims(self):
tokens = self.get_tokens()
if tokens:
return jwt.decode(tokens["id_token"], options={"verify_signature": False})
return None