Validate Cognito Token

How can we validate a Cognito token that is passed to an API endpoint?

These code snippets demonstrate how a Chalice server validates a Cognito JWT that is passed to an endpoint. After retrieving the JWT from the request's `Authorization: Bearer <token>` header, the server loads the pre-loaded Cognito configuration, which will be specific to your environment. It uses that configuration to configure a JwtUtility, which is adapted from AWS sample code and built on the python-jose project. It then has that utility validate the token. If validation succeeds, the token's email claim is returned.

--

app.py

--

from chalice import Chalice

from chalicelib import requestUtility

app = Chalice(app_name='rosetta-api-chalice')

# -----------------------------------------------------------------------------
# Convenience functions
# -----------------------------------------------------------------------------


def validateCognitoToken(requestHeaders):
    """Validate the request's Cognito bearer token and return its email claim.

    Thin wrapper that binds this module's ``app`` to
    ``RequestUtility.validateCognitoToken``; see that method for the
    exceptions raised on a missing or invalid token.
    """
    utility = requestUtility.RequestUtility()
    return utility.validateCognitoToken(app, requestHeaders)

--

chalicelib/requestUtility.py

--

import json
import os

from chalice import BadRequestError

from . jwtUtility import JwtUtility

# Preload config.json (stored next to this module) once, at import time, so
# it is not re-read on every request.
configFilePath = os.path.join(os.path.dirname(__file__), 'config.json')
# Read raw bytes and decode them explicitly instead of relying on the
# platform's default text encoding: JSON text is UTF-8, and 'utf-8-sig'
# additionally tolerates a leading byte-order mark.
with open(configFilePath, 'rb') as configFile:
    configFileJson = json.loads(configFile.read().decode('utf-8-sig'))

class RequestUtility(object):
    """Request helpers for the Chalice app, chiefly Cognito token validation.

    Reads the Cognito public keys, app client id and user pool id from the
    module-level ``configFileJson`` (loaded from config.json).
    """

    # region - Public Methods

    def validateCognitoToken(self, app, requestHeaders):
        """Validate the Cognito ID token carried in the Authorization header.

        :param app: the Chalice app (not used by this method; kept so existing
            callers keep working).
        :param requestHeaders: the request's header mapping, or None.
        :returns: the 'email' claim of the verified token.
        :raises BadRequestError: when no bearer token is present, the token
            fails verification, or the verified claims carry no email.
        """
        # NOTE(review): a missing/invalid credential is conventionally a 401
        # (chalice.UnauthorizedError); BadRequestError (400) is kept because
        # callers may depend on it.

        # guard clause - no bearer token value
        cognitoTokenValue = self._getBearerTokenValue(requestHeaders)
        if cognitoTokenValue is None:
            raise BadRequestError('No cognito token provided')

        # guard clause - invalid token value
        cognitoPublicKeys = configFileJson.get('cognito_public_keys', [])
        cognitoPublicKeyMap = self._mapCognitoPublicKeys(cognitoPublicKeys)
        cognitoAppClientId = configFileJson.get('cognito_app_client_id', '')
        cognitoUserPoolId = configFileJson.get('cognito_user_pool_id', '')
        jwtUtility = JwtUtility(cognitoPublicKeyMap, cognitoAppClientId, cognitoUserPoolId)
        cognitoTokenClaims = jwtUtility.verifyIdJwt(cognitoTokenValue)
        if cognitoTokenClaims is None:
            raise BadRequestError('Invalid cognito token provided')

        # guard clause - invalid token claims
        cognitoTokenEmail = cognitoTokenClaims.get('email')
        if cognitoTokenEmail is None:
            raise BadRequestError('Invalid cognito token claims provided')

        # otherwise, return the email
        return cognitoTokenEmail

    # region - Helper Methods

    def _getBearerTokenValue(self, requestHeaders):
        """Extract the token from an ``Authorization: Bearer <token>`` header.

        The scheme name is matched case-insensitively (RFC 7235, section 2.1)
        and only that leading scheme is removed, so the token is returned
        exactly as sent (minus surrounding whitespace).

        :param requestHeaders: the request's header mapping, or None.
        :returns: the token string, or None when there are no headers, no
            Authorization header, a non-Bearer scheme, or an empty token.
        """
        # guard clause - no headers
        if requestHeaders is None:
            return None

        # guard clause - no authorization header
        authorizationValue = requestHeaders.get('Authorization')
        if authorizationValue is None:
            return None

        # guard clause - not a bearer token; partition splits off only the
        # first space-delimited word (the scheme), unlike str.replace, which
        # would also strip any later occurrence of "Bearer "
        scheme, _, credentials = authorizationValue.partition(' ')
        if scheme.lower() != 'bearer':
            return None

        # an empty token ("Bearer ") is treated the same as no token at all
        token = credentials.strip()
        return token or None

    def _mapCognitoPublicKeys(self, cognitoPublicKeys):
        """Index the configured public keys by their 'kid' (key id).

        Keys without a 'kid' are skipped; if two keys share a 'kid', the
        later one wins.
        """
        return {
            publicKey['kid']: publicKey
            for publicKey in cognitoPublicKeys
            if publicKey.get('kid') is not None
        }

--

chalicelib/jwtUtility.py

--

# Adapted from:
# https://github.com/awslabs/aws-support-tools/blob/master/Cognito/decode-verify-jwt/decode-verify-jwt.py

# Identified by:
# https://docs.aws.amazon.com/cognito/latest/developerguide/amazon-cognito-user-pools-using-tokens-verifying-a-jwt.html

import time

from jose import jwk, jwt
from jose.utils import base64url_decode

class JwtUtility:
    """Verify Cognito ID tokens with python-jose.

    Holds the user pool's public keys (indexed by key id, 'kid') and the
    optional expected app client id ('aud') and user pool ('iss').
    """

    # region - Constructor

    def __init__(self, publicKeyMap, appClientId=None, userPoolId=None):
        """
        :param publicKeyMap: dict mapping key id ('kid') to that key's JWK dict.
        :param appClientId: expected 'aud' claim; None skips the audience check.
        :param userPoolId: expected user pool, given either as the bare pool id
            or as the full issuer URL; None skips the issuer check.
        """
        self._appClientId = appClientId
        self._publicKeyMap = publicKeyMap
        self._userPoolId = userPoolId

    # region - Public Methods

    def verifyIdJwt(self, token):
        """Verify a Cognito ID token and return its claims.

        The following checks must all pass:
          - the token's 'kid' header names a known public key;
          - the signature verifies against that key;
          - the token has not expired;
          - its 'aud', 'iss' and 'token_use' claims match expectations.

        :param token: the encoded JWT ("header.payload.signature").
        :returns: the claims dict when every check passes, otherwise None.
            A malformed token (bad segments, base64 or JSON) also yields None
            instead of raising.
        """
        # local import keeps this class free of new module-level imports;
        # JOSEError is the base of python-jose's JWTError/JWSError/JWKError
        from jose.exceptions import JOSEError

        try:
            # guard clause - no matching public key info
            headers = jwt.get_unverified_headers(token)
            publicKeyId = headers.get('kid')
            publicKeyJson = self._publicKeyMap.get(publicKeyId)
            if publicKeyJson is None:
                print('Public key {0} not found in public key map'.format(publicKeyId))
                for publicKey in self._publicKeyMap.keys():
                    print('Public key in public key map: {0}'.format(publicKey))
                return None

            # guard clause - invalid signature
            if not self._hasValidSignature(publicKeyJson, token):
                print('Signature verification failed')
                return None

            # the payload is covered by the signature verified above
            claims = jwt.get_unverified_claims(token)
        except (JOSEError, TypeError, ValueError) as error:
            # the token is untrusted input: report parse/verification errors
            # as "invalid" (None) instead of letting them escape to the caller
            print('Token could not be verified: {0}'.format(error))
            return None

        # guard clause - expired or mismatched claims
        if not self._hasValidClaims(claims):
            return None

        # otherwise, return the valid claims
        return claims

    # region - Helper Methods

    def _hasValidSignature(self, publicKeyJson, token):
        """Return whether the token's signature verifies against the given JWK.

        The signed message is everything before the last '.'; the signature
        is the base64url-encoded final segment.
        """
        publicKey = jwk.construct(publicKeyJson)
        rawTokenMessage, rawTokenSignature = str(token).rsplit('.', 1)
        encodedTokenMessage = rawTokenMessage.encode('utf-8')
        decodedTokenSignature = base64url_decode(rawTokenSignature.encode('utf-8'))
        return publicKey.verify(encodedTokenMessage, decodedTokenSignature)

    def _hasValidClaims(self, claims):
        """Return whether the expiry, audience, issuer and use claims pass.

        Missing claims count as failures instead of raising KeyError.
        """
        # guard clause - missing or expired claims
        claimsExpiration = claims.get('exp')
        if not isinstance(claimsExpiration, (int, float)):
            print('Token has no numeric expiration ({0})'.format(claimsExpiration))
            return False
        currentTime = time.time()
        validWindow = claimsExpiration - currentTime
        if validWindow <= 0:
            print('Token expired on {0} ({1})({2})'.format(claimsExpiration, currentTime, validWindow))
            return False

        # guard clause - mismatched audience
        claimsAudience = claims.get('aud')
        if self._appClientId is not None and claimsAudience != self._appClientId:
            print('Token issued for audience {0} ({1})'.format(claimsAudience, self._appClientId))
            return False

        # guard clause - mismatched issuer
        claimsIssuer = claims.get('iss')
        if self._userPoolId is not None and claimsIssuer not in self._expectedIssuers():
            print('Token issued for user pool {0} ({1})'.format(claimsIssuer, self._userPoolId))
            return False

        # guard clause - mismatched use
        claimsTokenUse = claims.get('token_use')
        if claimsTokenUse != 'id':
            print('Token issued for use {0} (rather than id)'.format(claimsTokenUse))
            return False

        return True

    def _expectedIssuers(self):
        """Return the list of 'iss' values accepted for the configured pool.

        The configured value is always accepted verbatim, so a config holding
        the full issuer URL keeps working. When the value looks like a bare
        pool id ('<region>_<suffix>'), the Cognito issuer URL
        'https://cognito-idp.<region>.amazonaws.com/<poolId>' is accepted as
        well. A list is used so unhashable claim values compare safely.
        """
        issuers = [self._userPoolId]
        poolId = str(self._userPoolId)
        region, separator, _ = poolId.partition('_')
        if region and separator and '/' not in poolId:
            issuers.append('https://cognito-idp.{0}.amazonaws.com/{1}'.format(region, poolId))
        return issuers

--