Godwin Vincent cd0a29d2ae
Device management agent - AgentCore runtime, observability, frontend added (#241)
* updated README.md file with bearer token generation

* updated README.md file with bearer token generation-removed client id and secret credentials

* removed hardcoded domain

* added agent runtime, frontend, observability and agentcore identity

* update README.md file to reflect frontend testing
2025-08-13 09:31:29 -07:00

151 lines
5.0 KiB
Python

"""
Authentication module for Cognito integration
"""
import os
import base64
import json
import logging
from urllib.parse import urlencode
from typing import Optional, Dict, Any
import httpx
from fastapi import Request, HTTPException, Depends
from fastapi.responses import RedirectResponse
from starlette.middleware.sessions import SessionMiddleware
from jose import jwk, jwt
from jose.utils import base64url_decode
# Logging: configure the root logger at INFO and create this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Cognito settings, read once at import time from the process environment.
# A variable that is not set is left as None.
COGNITO_DOMAIN = os.environ.get("COGNITO_DOMAIN")
COGNITO_CLIENT_ID = os.environ.get("COGNITO_CLIENT_ID")
COGNITO_CLIENT_SECRET = os.environ.get("COGNITO_CLIENT_SECRET")
COGNITO_REDIRECT_URI = os.environ.get("COGNITO_REDIRECT_URI")
COGNITO_LOGOUT_URI = os.environ.get("COGNITO_LOGOUT_URI")
AWS_REGION = os.environ.get("AWS_REGION")
COGNITO_USER_POOL_ID = os.environ.get("COGNITO_USER_POOL_ID")

# JWT validation: location of the user pool's JSON Web Key Set, and the
# module-level slot that get_jwks() fills on first use.
jwks_url = (
    f"https://cognito-idp.{AWS_REGION}.amazonaws.com/"
    f"{COGNITO_USER_POOL_ID}/.well-known/jwks.json"
)
jwks = None
async def get_jwks():
    """Return the Cognito JSON Web Key Set, fetching it on first use.

    The key set is downloaded from ``jwks_url`` and stored in the
    module-level ``jwks`` variable; later calls return the stored value
    without another request.

    Returns:
        The decoded JWKS document (a dict containing a ``"keys"`` list).

    Raises:
        httpx.HTTPStatusError: If the JWKS endpoint answers with a 4xx/5xx
            status.
        ValueError: If the response body is not a JWKS document.
    """
    global jwks
    if jwks is None:
        async with httpx.AsyncClient() as client:
            response = await client.get(jwks_url)
            # Reject error responses *before* caching: otherwise an error
            # body would be stored in ``jwks`` and every later validation
            # would fail until the process is restarted.
            response.raise_for_status()
            payload = response.json()
        if not isinstance(payload, dict) or "keys" not in payload:
            raise ValueError("JWKS response does not contain a 'keys' entry")
        # Only a well-formed key set is ever cached.
        jwks = payload
    return jwks
def get_login_url() -> str:
    """Generate the Cognito hosted-UI login URL.

    The query string is built with ``urlencode`` so that the redirect URI
    (which itself contains ``:``, ``/`` and possibly ``?``/``&``) is
    percent-encoded instead of being spliced raw into the URL. This matches
    how ``get_logout_url`` builds its query string.

    Returns:
        The ``/login`` URL on ``COGNITO_DOMAIN`` requesting the
        authorization-code flow for ``COGNITO_CLIENT_ID``.
    """
    query = urlencode(
        {
            "client_id": COGNITO_CLIENT_ID,
            "response_type": "code",
            "redirect_uri": COGNITO_REDIRECT_URI,
        }
    )
    # Use the Cognito hosted UI directly
    login_url = f"https://{COGNITO_DOMAIN}/login?{query}"

    # Debug logging of the (non-secret) settings that make up the URL
    logger.info("COGNITO_DOMAIN: %s", COGNITO_DOMAIN)
    logger.info("COGNITO_CLIENT_ID: %s", COGNITO_CLIENT_ID)
    logger.info("COGNITO_REDIRECT_URI: %s", COGNITO_REDIRECT_URI)
    logger.info("Full login URL: %s", login_url)
    return login_url
def get_logout_url() -> str:
    """Build the Cognito hosted-UI logout URL.

    Returns:
        The ``/logout`` URL on ``COGNITO_DOMAIN`` with ``client_id`` and
        ``logout_uri`` passed as URL-encoded query parameters.
    """
    query_string = urlencode(
        [
            ("client_id", COGNITO_CLIENT_ID),
            ("logout_uri", COGNITO_LOGOUT_URI),
        ]
    )
    return f"https://{COGNITO_DOMAIN}/logout?{query_string}"
async def exchange_code_for_tokens(code: str) -> Dict[str, Any]:
    """Trade an authorization code for tokens at the Cognito token endpoint.

    The client credentials travel in the form body (``client_id`` and
    ``client_secret`` fields) rather than in an ``Authorization`` header.

    Args:
        code: Authorization code received on the redirect from Cognito.

    Returns:
        The decoded JSON body of the token endpoint's response.

    Raises:
        HTTPException: 400 when the token endpoint answers with any status
            other than 200; the detail carries the upstream response text.
    """
    logger.info(f"Exchanging code for tokens with client_id: {COGNITO_CLIENT_ID}")

    form_fields = {
        "grant_type": "authorization_code",
        "client_id": COGNITO_CLIENT_ID,
        "client_secret": COGNITO_CLIENT_SECRET,
        "code": code,
        "redirect_uri": COGNITO_REDIRECT_URI
    }
    async with httpx.AsyncClient() as http:
        token_response = await http.post(
            f"https://{COGNITO_DOMAIN}/oauth2/token",
            headers={"Content-Type": "application/x-www-form-urlencoded"},
            data=form_fields,
        )

    # Success path first; everything else is reported as a failed exchange.
    if token_response.status_code == 200:
        return token_response.json()

    logger.error(f"Token exchange failed: {token_response.text}")
    raise HTTPException(status_code=400, detail=f"Failed to exchange code for tokens: {token_response.text}")
async def validate_token(token: str) -> Dict[str, Any]:
    """Validate a JWT issued by Cognito and return its claims.

    Steps performed:
      1. Read the ``kid`` from the (unverified) token header.
      2. Look up the matching public key in the JWKS from ``get_jwks()``.
      3. Verify the token signature with that key.
      4. Check that the token has not expired.
      5. Check that the token was issued for ``COGNITO_CLIENT_ID``. The app
         client id is read from ``client_id`` or, when that claim is absent,
         from ``aud``, so tokens carrying either claim are accepted.

    Args:
        token: The encoded JWT (``header.payload.signature``).

    Returns:
        The token's claims.

    Raises:
        HTTPException: 401 for any token problem (malformed token, unknown
            key id, bad signature, missing/expired ``exp``, wrong client).
    """
    import time

    from jose.exceptions import JOSEError

    # Get the key id from the token header. A token that cannot be parsed is
    # an authentication failure, not a server error.
    try:
        header = jwt.get_unverified_header(token)
    except JOSEError:
        raise HTTPException(status_code=401, detail="Invalid token: Malformed token") from None
    kid = header.get("kid")

    # Get the public key that matches the key id
    jwks_client = await get_jwks()
    key = None
    if kid:
        key = next(
            (candidate for candidate in jwks_client["keys"] if candidate.get("kid") == kid),
            None,
        )
    if not key:
        raise HTTPException(status_code=401, detail="Invalid token: Key not found")

    # Verify the signature over "<header>.<payload>"
    public_key = jwk.construct(key)
    try:
        message, encoded_signature = token.rsplit(".", 1)
        decoded_signature = base64url_decode(encoded_signature.encode())
        signature_ok = public_key.verify(message.encode(), decoded_signature)
    except (TypeError, ValueError):
        # Undecodable signature segment (bad base64 / padding).
        signature_ok = False
    if not signature_ok:
        raise HTTPException(status_code=401, detail="Invalid token: Signature verification failed")

    # Verify the claims (the signature has been checked above)
    try:
        claims = jwt.get_unverified_claims(token)
    except JOSEError:
        raise HTTPException(status_code=401, detail="Invalid token: Malformed claims") from None

    # Check expiration; a token without a usable "exp" is rejected.
    expires_at = claims.get("exp")
    if not isinstance(expires_at, (int, float)):
        raise HTTPException(status_code=401, detail="Invalid token: Missing expiration")
    if expires_at < time.time():
        raise HTTPException(status_code=401, detail="Token expired")

    # Check audience: prefer "client_id", fall back to "aud".
    token_client = claims.get("client_id") or claims.get("aud")
    if not token_client or token_client != COGNITO_CLIENT_ID:
        raise HTTPException(status_code=401, detail="Invalid audience")

    return claims
async def get_current_user(request: Request) -> Optional[Dict[str, Any]]:
    """Resolve the current user from the session, else from the simple-login cookie.

    Args:
        request: Incoming request; its ``session`` and ``cookies`` are read.

    Returns:
        The ``"user"`` entry of the session when present; otherwise a dict
        ``{"username": <cookie value>, "auth_type": "simple"}`` when a
        non-empty ``simple_user`` cookie is sent; otherwise ``None``.
    """
    # A Cognito-backed session takes precedence over the cookie fallback.
    session = request.session
    if "user" in session:
        return session["user"]

    # NOTE(review): the cookie value is used as the username as-is; confirm
    # that whatever sets "simple_user" makes this safe to trust.
    cookie_username = request.cookies.get("simple_user")
    if not cookie_username:
        return None
    return {"username": cookie_username, "auth_type": "simple"}
def login_required(request: Request):
    """Dependency that only lets requests with a logged-in session through.

    Args:
        request: Incoming request whose session is inspected.

    Returns:
        The ``"user"`` value stored in the session.

    Raises:
        HTTPException: 401 when the session has no (truthy) ``"user"`` entry.
    """
    session_user = request.session.get("user")
    if session_user:
        return session_user
    raise HTTPException(status_code=401, detail="Authentication required")