Validate Cognito Token
How can we validate a Cognito token that is passed to an API endpoint?
These code snippets demonstrate how a Chalice server validates a Cognito JWT that is passed to an endpoint. The server first retrieves the JWT from the request's `Authorization: Bearer <token>` header. It then uses the preloaded Cognito configuration (which is specific to your environment) to configure a JwtUtility, which is adapted from an AWS sample and built on the python-jose project. Finally, it has that utility validate the token. If the token is valid, the token's email claim is returned; otherwise a BadRequestError is raised.
--
app.py
--
from chalice import Chalice
from chalicelib import requestUtility
app = Chalice(app_name='rosetta-api-chalice')


#
# Convenience functions
#
def validateCognitoToken(requestHeaders):
    """Check the Cognito token carried in ``requestHeaders``.

    Thin wrapper: builds a ``requestUtility.RequestUtility`` and hands it this
    module's ``app`` together with the headers. Its result is returned
    unchanged, and anything it raises propagates to the caller.
    """
    utility = requestUtility.RequestUtility()
    return utility.validateCognitoToken(app, requestHeaders)
--
chalicelib/requestUtility.py
--
import json
import os
from chalice import BadRequestError
from . jwtUtility import JwtUtility
# Preload the config file once at import time so it is not re-read on every
# request. A missing or malformed config.json fails the import loudly rather
# than leaving an empty configuration behind.
configFilePath = os.path.join(os.path.dirname(__file__), 'config.json')
# JSON text is UTF-8; be explicit instead of relying on the platform/locale
# default encoding.
with open(configFilePath, encoding='utf-8') as configFile:
    configFileJson = json.load(configFile)
class RequestUtility(object):
    """Request helpers for the Chalice API (currently Cognito token checks)."""

    # region - Public Methods
    def validateCognitoToken(self, app, requestHeaders):
        """Validate the Cognito ID token sent in the ``Authorization`` header.

        :param app: the Chalice application; accepted for interface
            compatibility but not used here.
        :param requestHeaders: the request's header mapping (may be ``None``).
        :returns: the token's ``email`` claim.
        :raises BadRequestError: if no bearer token is present, the token fails
            verification, or it carries no usable ``email`` claim.
        """
        # guard clause - no bearer token value
        cognitoTokenValue = self._getBearerTokenValue(requestHeaders)
        if cognitoTokenValue is None:
            raise BadRequestError('No cognito token provided')
        # guard clause - invalid token value
        # Settings come from the module-level config loaded once at import.
        cognitoPublicKeys = configFileJson.get('cognito_public_keys', [])
        cognitoPublicKeyMap = self._mapCognitoPublicKeys(cognitoPublicKeys)
        # Empty-string defaults (not None) keep JwtUtility's audience/issuer
        # checks enabled, so a missing setting rejects tokens instead of
        # silently skipping the check.
        cognitoAppClientId = configFileJson.get('cognito_app_client_id', '')
        cognitoUserPoolId = configFileJson.get('cognito_user_pool_id', '')
        jwtUtility = JwtUtility(cognitoPublicKeyMap, cognitoAppClientId, cognitoUserPoolId)
        cognitoTokenClaims = jwtUtility.verifyIdJwt(cognitoTokenValue)
        if cognitoTokenClaims is None:
            raise BadRequestError('Invalid cognito token provided')
        # guard clause - invalid token claims (missing or empty email)
        cognitoTokenEmail = cognitoTokenClaims.get('email')
        if not cognitoTokenEmail:
            raise BadRequestError('Invalid cognito token claims provided')
        # otherwise, return the email
        return cognitoTokenEmail

    # region - Helper Methods
    def _getBearerTokenValue(self, requestHeaders):
        """Extract the token from an ``Authorization: Bearer <token>`` header.

        The scheme name is matched case-insensitively (RFC 7235 section 2.1).
        Only the leading scheme is removed, and surrounding whitespace is
        trimmed from the token.

        :returns: the token string, or ``None`` when there are no headers, no
            ``Authorization`` header, a non-bearer scheme, or an empty token.
        """
        # guard clause - no headers
        if requestHeaders is None:
            return None
        # guard clause - no authorization header
        authorizationValue = requestHeaders.get('Authorization')
        if authorizationValue is None:
            return None
        # guard clause - not a bearer credential
        scheme, _, tokenValue = authorizationValue.partition(' ')
        if scheme.lower() != 'bearer':
            return None
        # guard clause - bearer scheme with no token value
        tokenValue = tokenValue.strip()
        if not tokenValue:
            return None
        return tokenValue

    def _mapCognitoPublicKeys(self, cognitoPublicKeys):
        """Index JWK dicts by their ``kid``; entries without a ``kid`` are dropped.

        When two entries share a ``kid``, the later one wins.
        """
        return {
            publicKey['kid']: publicKey
            for publicKey in cognitoPublicKeys
            if publicKey.get('kid') is not None
        }
--
chalicelib/jwtUtility.py
--
# Adapted from:
# https://github.com/awslabs/aws-support-tools/blob/master/Cognito/decode-verify-jwt/decode-verify-jwt.py
# Identified by:
# https://docs.aws.amazon.com/cognito/latest/developerguide/amazon-cognito-user-pools-using-tokens-verifying-a-jwt.html
import time

from jose import jwk, jwt
from jose.exceptions import JOSEError
from jose.utils import base64url_decode
class JwtUtility:
    """Verify Amazon Cognito ID tokens against preloaded JWK public keys.

    Each failed check prints a diagnostic and returns ``None`` rather than
    raising, so callers only need a ``None`` test.
    """

    # region - Constructor
    def __init__(self, publicKeyMap, appClientId=None, userPoolId=None):
        """Store the verification settings.

        :param publicKeyMap: mapping of key id (``kid``) to the JWK dict used
            to check a token's signature.
        :param appClientId: expected ``aud`` claim; ``None`` skips the check.
        :param userPoolId: expected issuer, given either as the full ``iss``
            URL or as a bare user pool id such as ``us-east-1_Example``;
            ``None`` skips the check.
        """
        self._appClientId = appClientId
        self._publicKeyMap = publicKeyMap
        self._userPoolId = userPoolId

    # region - Public Methods
    def verifyIdJwt(self, token):
        """Return the claims of ``token`` if it is a valid ID token, else ``None``.

        For the claims to be returned, the token must:

        * decode as a JWT;
        * name a known key in its ``kid`` header;
        * carry a signature that verifies against that key;
        * be unexpired;
        * match the configured audience and issuer;
        * have ``token_use == 'id'``.

        Malformed tokens and missing claims yield ``None`` instead of an
        exception. Errors from building the configured key itself are not
        caught here.
        """
        # guard clause - malformed token (bad segments, base64 or JSON)
        try:
            headers = jwt.get_unverified_headers(token)
            claims = jwt.get_unverified_claims(token)
        except JOSEError as err:
            print('Token could not be decoded: {0}'.format(err))
            return None
        # guard clause - no matching public key info
        publicKeyId = headers.get('kid')
        # isinstance check avoids a TypeError on a missing or non-string kid
        publicKeyJson = self._publicKeyMap.get(publicKeyId) if isinstance(publicKeyId, str) else None
        if publicKeyJson is None:
            print('Public key {0} not found in public key map'.format(publicKeyId))
            for publicKey in self._publicKeyMap.keys():
                print('Public key in public key map: {0}'.format(publicKey))
            return None
        # guard clause - invalid signature
        if not self._hasValidSignature(token, publicKeyJson):
            print('Signature verification failed')
            return None
        # guard clause - missing or expired expiration claim
        claimsExpiration = claims.get('exp')
        if not isinstance(claimsExpiration, (int, float)):
            print('Token has no numeric expiration ({0})'.format(claimsExpiration))
            return None
        currentTime = time.time()
        validWindow = claimsExpiration - currentTime
        # written as "not > 0" so a NaN expiration is rejected as well
        if not validWindow > 0:
            print('Token expired on {0} ({1})({2})'.format(claimsExpiration, currentTime, validWindow))
            return None
        # guard clause - mismatched audience
        claimsAudience = claims.get('aud')
        if self._appClientId is not None and claimsAudience != self._appClientId:
            print('Token issued for audience {0} ({1})'.format(claimsAudience, self._appClientId))
            return None
        # guard clause - mismatched issuer
        claimsIssuer = claims.get('iss')
        if self._userPoolId is not None and claimsIssuer not in self._expectedIssuers():
            print('Token issued for user pool {0} ({1})'.format(claimsIssuer, self._userPoolId))
            return None
        # guard clause - mismatched use
        claimsTokenUse = claims.get('token_use')
        if claimsTokenUse != 'id':
            print('Token issued for use {0} (rather than id)'.format(claimsTokenUse))
            return None
        # otherwise, return the valid claims
        return claims

    # region - Helper Methods
    def _hasValidSignature(self, token, publicKeyJson):
        """Return whether the signature of ``token`` verifies against ``publicKeyJson``."""
        publicKey = jwk.construct(publicKeyJson)
        # Work in bytes so both str and bytes tokens are handled. The signature
        # covers "<header>.<payload>", and the last segment is the signature.
        tokenBytes = token.encode('utf-8') if isinstance(token, str) else token
        rawTokenMessage, rawTokenSignature = tokenBytes.rsplit(b'.', 1)
        decodedTokenSignature = base64url_decode(rawTokenSignature)
        return publicKey.verify(rawTokenMessage, decodedTokenSignature)

    def _expectedIssuers(self):
        """Return the ``iss`` values accepted for the configured user pool.

        The configured value is always accepted as-is, so a full issuer URL
        keeps working. A bare user pool id of the form ``<region>_<suffix>`` is
        also expanded to Cognito's documented issuer URL,
        ``https://cognito-idp.<region>.amazonaws.com/<userPoolId>``.
        A tuple is returned so membership tests use ``==`` and never hash the
        claim value.
        """
        userPoolId = self._userPoolId
        if not isinstance(userPoolId, str) or '/' in userPoolId:
            return (userPoolId,)
        region, separator, _ = userPoolId.partition('_')
        if not (separator and region):
            return (userPoolId,)
        issuerUrl = 'https://cognito-idp.{0}.amazonaws.com/{1}'.format(region, userPoolId)
        return (userPoolId, issuerUrl)
--