Skip to content
Snippets Groups Projects
refresh_token.py 5.15 KiB
Newer Older
  • Learn to ignore specific revisions
    #  Copyright 2020 Google LLC
    #  Copyright 2020 EPAM Systems
    #
    #  Licensed under the Apache License, Version 2.0 (the "License");
    #  you may not use this file except in compliance with the License.
    #  You may obtain a copy of the License at
    #
    #      http://www.apache.org/licenses/LICENSE-2.0
    #
    #  Unless required by applicable law or agreed to in writing, software
    #  distributed under the License is distributed on an "AS IS" BASIS,
    #  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    #  See the License for the specific language governing permissions and
    #  limitations under the License.
    
    import logging
    import sys
    
    import time
    from functools import partial
    
    from http import HTTPStatus
    
    import requests
    from airflow.models import Variable
    from google.auth.transport.requests import Request
    from google.oauth2 import service_account
    
    from libs.exceptions import RefreshSATokenError
    
    from tenacity import retry, stop_after_attempt
    
    # Process-local cache of the most recent access token. It is filled lazily
    # by _check_token() (from the Airflow Variable) or set_access_token().
    ACCESS_TOKEN = None

    # Set up base logger
    handler = logging.StreamHandler(sys.stdout)
    handler.setFormatter(
        logging.Formatter("%(asctime)s [%(name)-14.14s] [%(levelname)-7.7s]  %(message)s"))
    logger = logging.getLogger("Dataload")
    logger.setLevel(logging.INFO)
    # logging.getLogger() returns the same object for the same name, so running
    # this setup more than once in a process (e.g. a module reload) would stack
    # duplicate handlers and print every record several times.
    if not logger.handlers:
        logger.addHandler(handler)

    # Number of attempts used by the tenacity @retry decorators below.
    RETRIES = 3

    # Path to the service-account key file. Read from the Airflow Variable
    # "sa-file-osdu" when this module is imported.
    SA_FILE_PATH = Variable.get("sa-file-osdu")
    # OAuth scopes requested for the service-account credentials.
    ACCESS_SCOPES = ['openid', 'email', 'profile']
    
    
    
    @retry(stop=stop_after_attempt(RETRIES))
    def get_access_token(sa_file: str, scopes: list) -> str:
        """
        Refresh access token.

        Loads service-account credentials from ``sa_file`` with the requested
        ``scopes`` and refreshes them through google-auth's ``Request``
        transport. The whole function is retried up to ``RETRIES`` times.
        The decorator does not set ``reraise=True``, so after the last failed
        attempt tenacity raises ``tenacity.RetryError`` wrapping the final
        exception, rather than the exceptions listed below.

        :param sa_file: path to the service-account key file
        :param scopes: OAuth scopes to request
        :return: the raw access token (no "Bearer " prefix)
        :raises ValueError: (per attempt) if the key file cannot be parsed
        :raises RefreshSATokenError: (per attempt) if refreshing yields no token
        """
        try:
            credentials = service_account.Credentials.from_service_account_file(
                sa_file, scopes=scopes)
        except ValueError:
            logger.error("SA file has bad format.")
            # Bare ``raise`` re-raises the active exception with its traceback intact.
            raise

        logger.info("Refresh token.")
        credentials.refresh(Request())
        token = credentials.token

        if token is None:
            logger.error("Can't refresh token using SA-file. Token is empty.")
            raise RefreshSATokenError
        logger.info("Token refreshed.")

        return token
    
    
    @retry(stop=stop_after_attempt(RETRIES))
    def set_access_token(sa_file: str, scopes: list) -> str:
        """
        Fetch a fresh access token and publish it.

        The new token is stored both in the module-level ``ACCESS_TOKEN`` cache
        and in the Airflow Variable ``access_token``.

        NOTE(review): ``get_access_token`` carries its own ``@retry``. With
        tenacity's default retry-on-any-exception, a persistent failure can
        therefore cause up to ``RETRIES * RETRIES`` refresh attempts. Confirm
        this is intended.

        :param sa_file: path to the service-account key file
        :param scopes: OAuth scopes to request
        :return: an ``Authorization`` header value, ``"Bearer <token>"``
        """
        global ACCESS_TOKEN
        new_token = get_access_token(sa_file, scopes)
        bearer_value = f"Bearer {new_token}"
        # Publish to the in-process cache first, then to the shared Variable.
        ACCESS_TOKEN = new_token
        Variable.set("access_token", new_token)
        return bearer_value
    
    
    def _check_token():
        """
        Make sure ``ACCESS_TOKEN`` holds a token before a request is sent.

        If the module cache is empty, the token is loaded from the Airflow
        Variable ``access_token``. If ``Variable.get`` raises ``KeyError``
        (presumably because the Variable is not set), a brand-new token is
        obtained via ``set_access_token``. Token validity is not checked here;
        an expired token is only replaced when a request comes back 401/403.
        """
        global ACCESS_TOKEN
        if ACCESS_TOKEN:
            # Already cached in this process; nothing to do.
            return
        try:
            ACCESS_TOKEN = Variable.get('access_token')
        except KeyError:
            set_access_token(SA_FILE_PATH, ACCESS_SCOPES)
    
    
    
    def make_callable_request(obj, request_function, headers, *args, **kwargs):
        """
        Create send_request_with_auth function.

        Writes the current ``ACCESS_TOKEN`` into ``headers["Authorization"]``
        (the caller's dict is updated in place) and binds ``request_function``
        to its arguments.

        :param obj: instance passed as the first positional argument when
            ``request_function`` is a method, or ``None`` for a plain function
        :param request_function: the wrapped request function or method
        :param headers: request headers dict; mutated in place
        :return: a zero-argument callable that performs the request
        """
        headers["Authorization"] = f"Bearer {ACCESS_TOKEN}"
        # Test against None, not truthiness: an instance whose class defines
        # __bool__/__len__ can be falsy and must still be passed as ``self``.
        if obj is not None:  # wrapped function is an object's method
            return partial(request_function, obj, headers, *args, **kwargs)
        return partial(request_function, headers, *args, **kwargs)


    def _ensure_response(response, request_function):
        """
        Return ``response`` unchanged if it is a ``requests.Response``.

        :raises TypeError: otherwise, after logging what was returned instead
        """
        if not isinstance(response, requests.Response):
            logger.error("Function %s must return values of type requests.Response. "
                         "Got %s instead", request_function, type(response))
            raise TypeError(
                f"{request_function!r} returned {type(response)!r}, "
                "expected requests.Response")
        return response


    def _wrapper(*args, **kwargs):
        """
        Generic decorator wrapper for checking token and refreshing it.

        Consumes the keyword arguments ``request_function``, ``headers`` and,
        for methods, ``obj``. Every other argument is forwarded to
        ``request_function``. If the response status is 401 or 403, the token
        is refreshed and the request is sent exactly once more.

        :return: the ``requests.Response`` of the last attempt
        :raises TypeError: if ``headers`` is not a dict, or ``request_function``
            does not return a ``requests.Response``
        :raises requests.HTTPError: if the final response is not OK
            (raised by ``Response.raise_for_status``)
        """
        _check_token()
        # Pop with a default so a falsy-but-present ``obj`` is still removed
        # from kwargs instead of being forwarded to request_function.
        obj = kwargs.pop("obj", None)
        headers = kwargs.pop("headers")
        request_function = kwargs.pop("request_function")

        if not isinstance(headers, dict):
            # Lazy logging args: eager ``"..." % headers`` breaks for tuple values.
            logger.error("Got headers %s", headers)
            raise TypeError(f"headers must be a dict, got {type(headers)!r}")

        send_request_with_auth = make_callable_request(obj, request_function, headers,
                                                       *args, **kwargs)
        response = _ensure_response(send_request_with_auth(), request_function)

        if not response.ok:
            if response.status_code in (HTTPStatus.UNAUTHORIZED, HTTPStatus.FORBIDDEN):
                # Token was rejected: refresh it and try once more.
                set_access_token(SA_FILE_PATH, ACCESS_SCOPES)
                # Rebuild the call so the new ACCESS_TOKEN is written into headers.
                send_request_with_auth = make_callable_request(obj, request_function, headers,
                                                               *args, **kwargs)
                response = _ensure_response(send_request_with_auth(), request_function)
            response.raise_for_status()

        # Hand the response back so the decorated function's caller receives it.
        return response
    
    def refresh_token(request_function):
    
        """
        Wrap a request function and check response. If response's error status code
        is about Authorization, refresh token and invoke this function once again.
        Expects function:
    
        If response is not ok and not about Authorization, then raises HTTPError
    
        request_func(header: dict, *args, **kwargs) -> requests.Response
        Or method:
        request_method(self, header: dict, *args, **kwargs) -> requests.Response
        """
    
        is_method = len(request_function.__qualname__.split(".")) > 1
    
        if is_method:
            def wrapper(obj, headers, *args, **kwargs):
    
                return _wrapper(request_function=request_function, obj=obj, headers=headers, *args,
                                **kwargs)
    
        else:
            def wrapper(headers, *args, **kwargs):
    
                return _wrapper(request_function=request_function, headers=headers, *args, **kwargs)