Skip to content
Snippets Groups Projects
Commit 9a251e54 authored by Siarhei Khaletski (EPAM)'s avatar Siarhei Khaletski (EPAM) :triangular_flag_on_post:
Browse files

Merge branch 'trusted-feature/base-token-refresher' into 'master'

Added BaseTokenRefresher

See merge request !34
parents 2d1ee64d 48e9789a
No related branches found
No related tags found
1 merge request!34Added BaseTokenRefresher
Pipeline #33133 passed
...@@ -15,15 +15,9 @@ ...@@ -15,15 +15,9 @@
"""Auth and refresh token utility functions.""" """Auth and refresh token utility functions."""
import json
import logging import logging
import os
from abc import ABC, abstractmethod
from functools import partial
from http import HTTPStatus
import requests from osdu_api.libs.auth.authorization import TokenRefresher
from osdu_api.libs.auth.authorization import TokenRefresher, authorize
from providers import credentials from providers import credentials
from providers.types import BaseCredentials from providers.types import BaseCredentials
from tenacity import retry, stop_after_attempt from tenacity import retry, stop_after_attempt
...@@ -33,14 +27,12 @@ logger = logging.getLogger(__name__) ...@@ -33,14 +27,12 @@ logger = logging.getLogger(__name__)
RETRIES = 3 RETRIES = 3
class AirflowTokenRefresher(TokenRefresher): class BaseTokenRefresher(TokenRefresher):
"""Simple wrapper for credentials to be used in refresh_token decorator within Airflow.""" """Base Token refresher, that works with Credentials and has methods to refresh access tokens """
def __init__(self, creds: BaseCredentials = None): def __init__(self, creds: BaseCredentials = None):
super().__init__()
self._credentials = creds or credentials.get_credentials() self._credentials = creds or credentials.get_credentials()
from airflow.models import Variable
self.airflow_variables = Variable
self._access_token = None
@retry(stop=stop_after_attempt(RETRIES)) @retry(stop=stop_after_attempt(RETRIES))
def refresh_token(self) -> str: def refresh_token(self) -> str:
...@@ -50,7 +42,6 @@ class AirflowTokenRefresher(TokenRefresher): ...@@ -50,7 +42,6 @@ class AirflowTokenRefresher(TokenRefresher):
:rtype: str :rtype: str
""" """
self._credentials.refresh_token() self._credentials.refresh_token()
self.airflow_variables.set("core__auth__access_token", self._credentials.access_token)
self._access_token = self._credentials.access_token self._access_token = self._credentials.access_token
return self._access_token return self._access_token
...@@ -61,11 +52,6 @@ class AirflowTokenRefresher(TokenRefresher): ...@@ -61,11 +52,6 @@ class AirflowTokenRefresher(TokenRefresher):
:return: The access token :return: The access token
:rtype: str :rtype: str
""" """
if not self._access_token:
try:
self._access_token = self.airflow_variables.get("core__auth__access_token")
except KeyError:
self.refresh_token()
return self._access_token return self._access_token
@property @property
...@@ -76,3 +62,38 @@ class AirflowTokenRefresher(TokenRefresher): ...@@ -76,3 +62,38 @@ class AirflowTokenRefresher(TokenRefresher):
:rtype: dict :rtype: dict
""" """
return {"Authorization": f"Bearer {self.access_token}"} return {"Authorization": f"Bearer {self.access_token}"}
class AirflowTokenRefresher(BaseTokenRefresher):
"""Simple wrapper for credentials to be used in refresh_token decorator within Airflow."""
def __init__(self, creds: BaseCredentials = None):
super().__init__(creds)
from airflow.models import Variable
self.airflow_variables = Variable
self._access_token = None
@retry(stop=stop_after_attempt(RETRIES))
def refresh_token(self) -> str:
"""Refresh the token and cache token using airflow variables.
:return: The refreshed token
:rtype: str
"""
super().refresh_token()
self.airflow_variables.set("core__auth__access_token", self._access_token)
return self._access_token
@property
def access_token(self) -> str:
"""The access token.
:return: The access token
:rtype: str
"""
if not self._access_token:
try:
self._access_token = self.airflow_variables.get("core__auth__access_token")
except KeyError:
self.refresh_token()
return self._access_token
...@@ -23,10 +23,34 @@ from google.oauth2 import service_account ...@@ -23,10 +23,34 @@ from google.oauth2 import service_account
sys.path.append(f"{os.getenv('AIRFLOW_SRC_DIR')}/dags") sys.path.append(f"{os.getenv('AIRFLOW_SRC_DIR')}/dags")
from libs.refresh_token import AirflowTokenRefresher from libs.refresh_token import AirflowTokenRefresher, BaseTokenRefresher
from mock_providers import get_test_credentials from mock_providers import get_test_credentials
class TestBaseTokenRefresher:
@pytest.fixture()
def token_refresher(self, access_token: str) -> BaseTokenRefresher:
creds = get_test_credentials()
creds.access_token = access_token
token_refresher = BaseTokenRefresher(creds)
return token_refresher
@pytest.mark.parametrize(
"access_token",
[
"test",
"aaaa"
]
)
def test_authorization_header(self, token_refresher: BaseTokenRefresher, access_token: str):
"""
Check if Authorization header is 'Bearer <access_token>'
"""
token_refresher.refresh_token()
assert token_refresher.authorization_header.get("Authorization") == f"Bearer {access_token}"
class TestAirflowTokenRefresher: class TestAirflowTokenRefresher:
@pytest.fixture() @pytest.fixture()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment