Commit 48e9789a authored by Yan Sushchynski (EPAM)'s avatar Yan Sushchynski (EPAM) Committed by Siarhei Khaletski (EPAM)
Browse files

Add BaseTokenRefresher

parent 2d1ee64d
Pipeline #33124 passed with stages
in 8 minutes and 30 seconds
......@@ -15,15 +15,9 @@
"""Auth and refresh token utility functions."""
import json
import logging
import os
from abc import ABC, abstractmethod
from functools import partial
from http import HTTPStatus
import requests
from osdu_api.libs.auth.authorization import TokenRefresher, authorize
from osdu_api.libs.auth.authorization import TokenRefresher
from providers import credentials
from providers.types import BaseCredentials
from tenacity import retry, stop_after_attempt
......@@ -33,14 +27,12 @@ logger = logging.getLogger(__name__)
RETRIES = 3
class AirflowTokenRefresher(TokenRefresher):
"""Simple wrapper for credentials to be used in refresh_token decorator within Airflow."""
class BaseTokenRefresher(TokenRefresher):
"""Base Token refresher, that works with Credentials and has methods to refresh access tokens """
def __init__(self, creds: BaseCredentials = None):
super().__init__()
self._credentials = creds or credentials.get_credentials()
from airflow.models import Variable
self.airflow_variables = Variable
self._access_token = None
@retry(stop=stop_after_attempt(RETRIES))
def refresh_token(self) -> str:
......@@ -50,7 +42,6 @@ class AirflowTokenRefresher(TokenRefresher):
:rtype: str
"""
self._credentials.refresh_token()
self.airflow_variables.set("core__auth__access_token", self._credentials.access_token)
self._access_token = self._credentials.access_token
return self._access_token
......@@ -61,11 +52,6 @@ class AirflowTokenRefresher(TokenRefresher):
:return: The access token
:rtype: str
"""
if not self._access_token:
try:
self._access_token = self.airflow_variables.get("core__auth__access_token")
except KeyError:
self.refresh_token()
return self._access_token
@property
......@@ -76,3 +62,38 @@ class AirflowTokenRefresher(TokenRefresher):
:rtype: dict
"""
return {"Authorization": f"Bearer {self.access_token}"}
class AirflowTokenRefresher(BaseTokenRefresher):
"""Simple wrapper for credentials to be used in refresh_token decorator within Airflow."""
def __init__(self, creds: BaseCredentials = None):
super().__init__(creds)
from airflow.models import Variable
self.airflow_variables = Variable
self._access_token = None
@retry(stop=stop_after_attempt(RETRIES))
def refresh_token(self) -> str:
"""Refresh the token and cache token using airflow variables.
:return: The refreshed token
:rtype: str
"""
super().refresh_token()
self.airflow_variables.set("core__auth__access_token", self._access_token)
return self._access_token
@property
def access_token(self) -> str:
"""The access token.
:return: The access token
:rtype: str
"""
if not self._access_token:
try:
self._access_token = self.airflow_variables.get("core__auth__access_token")
except KeyError:
self.refresh_token()
return self._access_token
......@@ -23,10 +23,34 @@ from google.oauth2 import service_account
sys.path.append(f"{os.getenv('AIRFLOW_SRC_DIR')}/dags")
from libs.refresh_token import AirflowTokenRefresher
from libs.refresh_token import AirflowTokenRefresher, BaseTokenRefresher
from mock_providers import get_test_credentials
class TestBaseTokenRefresher:
@pytest.fixture()
def token_refresher(self, access_token: str) -> BaseTokenRefresher:
creds = get_test_credentials()
creds.access_token = access_token
token_refresher = BaseTokenRefresher(creds)
return token_refresher
@pytest.mark.parametrize(
"access_token",
[
"test",
"aaaa"
]
)
def test_authorization_header(self, token_refresher: BaseTokenRefresher, access_token: str):
"""
Check if Authorization header is 'Bearer <access_token>'
"""
token_refresher.refresh_token()
assert token_refresher.authorization_header.get("Authorization") == f"Bearer {access_token}"
class TestAirflowTokenRefresher:
@pytest.fixture()
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment