diff --git a/src/dags/libs/refresh_token.py b/src/dags/libs/refresh_token.py index 8ab038dfbf637c1853a154e5432b4efa1b762573..c8672d09661ecd41497e5717352b5280b4a509cf 100644 --- a/src/dags/libs/refresh_token.py +++ b/src/dags/libs/refresh_token.py @@ -15,15 +15,9 @@ """Auth and refresh token utility functions.""" -import json import logging -import os -from abc import ABC, abstractmethod -from functools import partial -from http import HTTPStatus -import requests -from osdu_api.libs.auth.authorization import TokenRefresher, authorize +from osdu_api.libs.auth.authorization import TokenRefresher from providers import credentials from providers.types import BaseCredentials from tenacity import retry, stop_after_attempt @@ -33,14 +27,12 @@ logger = logging.getLogger(__name__) RETRIES = 3 -class AirflowTokenRefresher(TokenRefresher): - """Simple wrapper for credentials to be used in refresh_token decorator within Airflow.""" +class BaseTokenRefresher(TokenRefresher): + """Base Token refresher, that works with Credentials and has methods to refresh access tokens """ def __init__(self, creds: BaseCredentials = None): + super().__init__() self._credentials = creds or credentials.get_credentials() - from airflow.models import Variable - self.airflow_variables = Variable - self._access_token = None @retry(stop=stop_after_attempt(RETRIES)) def refresh_token(self) -> str: @@ -50,7 +42,6 @@ class AirflowTokenRefresher(TokenRefresher): :rtype: str """ self._credentials.refresh_token() - self.airflow_variables.set("core__auth__access_token", self._credentials.access_token) self._access_token = self._credentials.access_token return self._access_token @@ -61,11 +52,6 @@ class AirflowTokenRefresher(TokenRefresher): :return: The access token :rtype: str """ - if not self._access_token: - try: - self._access_token = self.airflow_variables.get("core__auth__access_token") - except KeyError: - self.refresh_token() return self._access_token @property @@ -76,3 +62,38 @@ class AirflowTokenRefresher(TokenRefresher): :rtype: dict """ return {"Authorization": f"Bearer {self.access_token}"} + + +class AirflowTokenRefresher(BaseTokenRefresher): + """Simple wrapper for credentials to be used in refresh_token decorator within Airflow.""" + + def __init__(self, creds: BaseCredentials = None): + super().__init__(creds) + from airflow.models import Variable + self.airflow_variables = Variable + self._access_token = None + + @retry(stop=stop_after_attempt(RETRIES)) + def refresh_token(self) -> str: + """Refresh the token and cache token using airflow variables. + + :return: The refreshed token + :rtype: str + """ + super().refresh_token() + self.airflow_variables.set("core__auth__access_token", self._access_token) + return self._access_token + + @property + def access_token(self) -> str: + """The access token. + + :return: The access token + :rtype: str + """ + if not self._access_token: + try: + self._access_token = self.airflow_variables.get("core__auth__access_token") + except KeyError: + self.refresh_token() + return self._access_token diff --git a/tests/plugin-unit-tests/test_refresh_token.py b/tests/plugin-unit-tests/test_refresh_token.py index f468c3c618465d5bfa846b8e65ee3fc7d526cb7e..fcfb0f4c266446388f0eacc75732dd96ec4147e9 100644 --- a/tests/plugin-unit-tests/test_refresh_token.py +++ b/tests/plugin-unit-tests/test_refresh_token.py @@ -23,10 +23,34 @@ from google.oauth2 import service_account sys.path.append(f"{os.getenv('AIRFLOW_SRC_DIR')}/dags") -from libs.refresh_token import AirflowTokenRefresher +from libs.refresh_token import AirflowTokenRefresher, BaseTokenRefresher from mock_providers import get_test_credentials +class TestBaseTokenRefresher: + + @pytest.fixture() + def token_refresher(self, access_token: str) -> BaseTokenRefresher: + creds = get_test_credentials() + creds.access_token = access_token + token_refresher = BaseTokenRefresher(creds) + return token_refresher + + @pytest.mark.parametrize( + "access_token", + [ + "test", + "aaaa" + ] + ) + def test_authorization_header(self, token_refresher: BaseTokenRefresher, access_token: str): + """ + Check if Authorization header is 'Bearer <access_token>' + """ + token_refresher.refresh_token() + assert token_refresher.authorization_header.get("Authorization") == f"Bearer {access_token}" + + class TestAirflowTokenRefresher: @pytest.fixture()