# Source code for adal.token_cache

#------------------------------------------------------------------------------
#
# Copyright (c) Microsoft Corporation. 
# All rights reserved.
# 
# This code is licensed under the MIT License.
# 
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files(the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and / or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions :
# 
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
# 
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#
#------------------------------------------------------------------------------

import json
import threading

from .constants import TokenResponseFields

def _string_cmp(str1, str2):
    '''Case insensitive comparison. Return true if both are None'''
    str1 = str1 if str1 is not None else ''
    str2 = str2 if str2 is not None else ''
    return str1.lower() == str2.lower()

class TokenCacheKey(object): # pylint: disable=too-few-public-methods
    '''Dictionary key identifying one cached token entry.

    Equality ignores case on all four fields and treats ``None`` the same as
    ``''``. ``__hash__`` is computed from the same normalized values, so keys
    that compare equal always hash equal, as ``dict`` requires.
    '''

    def __init__(self, authority, resource, client_id, user_id):
        self.authority = authority
        self.resource = resource
        self.client_id = client_id
        self.user_id = user_id

    def _normalized(self):
        '''Return the four fields lower-cased, with None mapped to ''.'''
        return tuple(
            ('' if value is None else value).lower()
            for value in (self.authority, self.resource,
                          self.client_id, self.user_id))

    def __hash__(self):
        # Hash the normalized fields rather than the raw ones. Otherwise two
        # keys that differ only by case (or None vs '') compare equal but land
        # in different hash buckets, and dict lookups miss existing entries.
        return hash(self._normalized())

    def __eq__(self, other):
        if not isinstance(other, TokenCacheKey):
            # Let Python fall back to its default comparison instead of
            # raising AttributeError on objects without these fields.
            return NotImplemented
        return self._normalized() == other._normalized()

    def __ne__(self, other):
        # Needed on Python 2, which does not derive __ne__ from __eq__.
        return not self == other

# pylint: disable=protected-access

def _get_cache_key(entry):
    '''Build the TokenCacheKey under which ``entry`` is stored in the cache.

    Reads the authority, resource, client id and user id fields from the
    ``entry`` mapping via ``.get``; absent fields become ``None``.
    '''
    key_fields = (
        TokenResponseFields._AUTHORITY,
        TokenResponseFields.RESOURCE,
        TokenResponseFields._CLIENT_ID,
        TokenResponseFields.USER_ID,
    )
    return TokenCacheKey(*[entry.get(field) for field in key_fields])


class TokenCache(object):
    '''In-memory store of token response entries, keyed by TokenCacheKey.

    Every public method acquires ``self._lock`` (a ``threading.RLock``)
    before touching ``self._cache``. ``has_state_changed`` is set to True by
    ``add`` and by ``remove`` when an entry is actually removed. Within this
    class it is set to False only in ``__init__``.
    '''

    def __init__(self, state=None):
        '''Create a cache, optionally pre-populated from ``state``.

        :param state: JSON string as produced by ``serialize``; None or an
            empty string yields an empty cache.
        '''
        self._cache = {}
        self._lock = threading.RLock()
        if state:
            self.deserialize(state)
        # Loading the initial state does not count as a change.
        self.has_state_changed = False

    def find(self, query):
        '''Return a list of cached entries matching ``query``.

        ``query`` is read via ``.get`` for the IS_MRRT, USER_ID and
        _CLIENT_ID fields. A field that is missing (None) matches any entry.
        '''
        with self._lock:
            return self._query_cache(
                query.get(TokenResponseFields.IS_MRRT),
                query.get(TokenResponseFields.USER_ID),
                query.get(TokenResponseFields._CLIENT_ID))

    def remove(self, entries):
        '''Remove each entry in ``entries``; entries not in the cache are ignored.'''
        with self._lock:
            for entry in entries:
                key = _get_cache_key(entry)
                removed = self._cache.pop(key, None)
                if removed is not None:
                    self.has_state_changed = True

    def add(self, entries):
        '''Insert each entry in ``entries``, overwriting any entry with an equal key.'''
        with self._lock:
            for entry in entries:
                key = _get_cache_key(entry)
                self._cache[key] = entry
                self.has_state_changed = True

    def serialize(self):
        '''Return all cached entries as a JSON array string.'''
        with self._lock:
            return json.dumps(list(self._cache.values()))

    def deserialize(self, state):
        '''Replace the cache contents with the entries in JSON string ``state``.

        A falsy ``state`` empties the cache. ``state`` is parsed and re-keyed
        before the current contents are discarded, so malformed input leaves
        the existing cache untouched. ``has_state_changed`` is not modified.

        :raises ValueError: if ``state`` is not valid JSON.
        '''
        with self._lock:
            loaded = {}
            if state:
                for token in json.loads(state):
                    loaded[_get_cache_key(token)] = token
            # Swap only after parsing succeeded.
            self._cache.clear()
            self._cache.update(loaded)

    def read_items(self):
        '''Return a list of (TokenCacheKey, entry) tuples.

        The list is a snapshot taken under the lock. A live dict view would
        escape the lock and could raise "dictionary changed size during
        iteration" if another thread calls add/remove while the caller
        iterates it.
        '''
        with self._lock:
            return list(self._cache.items())

    def _query_cache(self, is_mrrt, user_id, client_id):
        '''Return entries matching all non-None criteria.

        Called from ``find`` while ``self._lock`` is held.
        '''
        matches = []
        for entry in self._cache.values():
            #None value will be taken as wildcard match
            #pylint: disable=too-many-boolean-expressions
            if ((is_mrrt is None or is_mrrt == entry.get(TokenResponseFields.IS_MRRT)) and
                    (user_id is None or
                     _string_cmp(user_id, entry.get(TokenResponseFields.USER_ID))) and
                    (client_id is None or
                     _string_cmp(client_id, entry.get(TokenResponseFields._CLIENT_ID)))):
                matches.append(entry)
        return matches