Newer
Older
# Copyright 2020 Google LLC
# Copyright 2020 EPAM Systems
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import sys
import time
from functools import partial
from http import HTTPStatus
import requests
from airflow.models import Variable
from google.auth.transport.requests import Request
from google.oauth2 import service_account
from libs.exceptions import RefreshSATokenError
from tenacity import retry, stop_after_attempt
# Cached bearer token; filled lazily by _check_token() / set_access_token().
ACCESS_TOKEN = None

# Base logger for this module: INFO level, written to stdout.
handler = logging.StreamHandler(sys.stdout)
handler.setFormatter(logging.Formatter(
    "%(asctime)s [%(name)-14.14s] [%(levelname)-7.7s] %(message)s"
))
logger = logging.getLogger("Dataload")
logger.setLevel(logging.INFO)
logger.addHandler(handler)

# Attempt limit for the tenacity-decorated token functions.
RETRIES = 3
# NOTE(review): this reads an Airflow Variable at import time, i.e. every time
# the module is imported (including DAG parsing).
SA_FILE_PATH = Variable.get("sa-file-osdu")
ACCESS_SCOPES = ['openid', 'email', 'profile']
@retry(stop=stop_after_attempt(RETRIES))
def get_access_token(sa_file: str, scopes: list) -> str:
    """
    Obtain a freshly refreshed access token from a service-account key file.

    Each attempt loads the credentials, refreshes them and returns the token.
    An attempt fails with ``ValueError`` if the key file cannot be parsed, or
    with ``RefreshSATokenError`` if the refresh produced no token. The tenacity
    decorator retries up to ``RETRIES`` attempts. Because ``reraise`` is not
    set, exhausting the attempts raises ``tenacity.RetryError`` that wraps the
    last failure.

    :param sa_file: path to the service-account key file.
    :param scopes: OAuth scopes requested for the credentials.
    :return: the refreshed access token.
    """
    try:
        creds = service_account.Credentials.from_service_account_file(
            sa_file, scopes=scopes)
    except ValueError:
        logger.error("SA file has bad format.")
        raise

    logger.info("Refresh token.")
    creds.refresh(Request())
    access_token = creds.token
    if access_token is None:
        logger.error("Can't refresh token using SA-file. Token is empty.")
        raise RefreshSATokenError

    logger.info("Token refreshed.")
    return access_token
@retry(stop=stop_after_attempt(RETRIES))
def set_access_token(sa_file: str, scopes: list) -> str:
    """
    Fetch a new access token and publish it.

    The token is stored in the module-level ``ACCESS_TOKEN`` and in the
    Airflow Variable ``access_token``.

    NOTE(review): ``get_access_token`` is itself retried, so these nested
    retries can make up to RETRIES * RETRIES token attempts.

    :param sa_file: path to the service-account key file.
    :param scopes: OAuth scopes requested for the credentials.
    :return: the ``Authorization`` header value, ``"Bearer <token>"``.
    """
    global ACCESS_TOKEN
    new_token = get_access_token(sa_file, scopes)
    ACCESS_TOKEN = new_token
    Variable.set("access_token", new_token)
    return f"Bearer {new_token}"
def _check_token():
    """
    Make sure ``ACCESS_TOKEN`` is populated before a request is sent.

    If ``ACCESS_TOKEN`` is already truthy, nothing happens. Otherwise the token
    is loaded from the Airflow Variable ``access_token``. When that lookup
    raises ``KeyError``, a new token is fetched via ``set_access_token``.
    """
    global ACCESS_TOKEN
    if ACCESS_TOKEN:
        return
    try:
        ACCESS_TOKEN = Variable.get('access_token')
    except KeyError:
        set_access_token(SA_FILE_PATH, ACCESS_SCOPES)
def make_callable_request(obj, request_function, headers, *args, **kwargs):
    """
    Build a zero-argument callable that sends the request with the current token.

    The ``Authorization`` header is written into ``headers`` in place (the
    caller's dict), using the value of the module-level ``ACCESS_TOKEN`` at
    the moment this function is called.

    :param obj: instance to pass as ``self`` when ``request_function`` is a
        method, or ``None`` for a plain function.
    :param request_function: the undecorated request function or method.
    :param headers: request headers dict; mutated in place.
    :return: a callable taking no arguments that performs the request.
    """
    headers["Authorization"] = f"Bearer {ACCESS_TOKEN}"
    # Compare against None, not truthiness: an instance whose __bool__/__len__
    # makes it falsy is still a valid ``self``.
    if obj is not None:  # wrapped function is an object's method
        return partial(request_function, obj, headers, *args, **kwargs)
    return partial(request_function, headers, *args, **kwargs)


def _wrapper(*args, **kwargs):
    """
    Send a request with the current token, refreshing the token once on 401/403.

    Expects the keyword arguments ``request_function`` and ``headers`` (plus
    ``obj`` for methods), as passed by ``refresh_token``. All remaining
    positional and keyword arguments are forwarded to ``request_function``.

    :return: the response of the (possibly re-sent) request.
    :raises TypeError: if ``headers`` is not a dict.
    :raises requests.HTTPError: via ``raise_for_status`` when the final
        response is not ok.
    """
    _check_token()
    # Always pop "obj": a falsy instance must not leak into the forwarded
    # kwargs as a stray ``obj=`` argument.
    obj = kwargs.pop("obj", None)
    headers = kwargs.pop("headers")
    request_function = kwargs.pop("request_function")
    if not isinstance(headers, dict):
        # Lazy logging args: "%s" % headers breaks when headers is a tuple.
        logger.error("Got headers %s", headers)
        raise TypeError(f"headers must be a dict, got {type(headers).__name__}")

    send_request_with_auth = make_callable_request(obj, request_function, headers,
                                                   *args, **kwargs)
    response = send_request_with_auth()
    if not isinstance(response, requests.Response):
        # Only logged, not enforced: duck-typed responses exposing ``ok``,
        # ``status_code`` and ``raise_for_status`` keep working.
        logger.error("Function %s must return values of type requests.Response. "
                     "Got %s instead", request_function, type(response))

    if not response.ok:
        if response.status_code in (HTTPStatus.UNAUTHORIZED, HTTPStatus.FORBIDDEN):
            # Token rejected: fetch a new one and re-send exactly once. The new
            # callable picks up the refreshed global ACCESS_TOKEN.
            set_access_token(SA_FILE_PATH, ACCESS_SCOPES)
            send_request_with_auth = make_callable_request(obj, request_function, headers,
                                                           *args, **kwargs)
            response = send_request_with_auth()
        response.raise_for_status()
    return response
def refresh_token(request_function):
"""
Wrap a request function and check response. If response's error status code
is about Authorization, refresh token and invoke this function once again.
Expects function:
If response is not ok and not about Authorization, then raises HTTPError
request_func(header: dict, *args, **kwargs) -> requests.Response
Or method:
request_method(self, header: dict, *args, **kwargs) -> requests.Response
"""
is_method = len(request_function.__qualname__.split(".")) > 1
if is_method:
def wrapper(obj, headers, *args, **kwargs):
return _wrapper(request_function=request_function, obj=obj, headers=headers, *args,
**kwargs)
else:
def wrapper(headers, *args, **kwargs):
return _wrapper(request_function=request_function, headers=headers, *args, **kwargs)