Newer
Older
# Copyright 2020 Google LLC
# Copyright 2020 EPAM Systems
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
import sys
import pytest
from google.oauth2 import service_account
sys.path.append(f"{os.getenv('AIRFLOW_SRC_DIR')}/dags")
from libs.refresh_token import AirflowTokenRefresher, BaseTokenRefresher
from mock_providers import get_test_credentials
class TestBaseTokenRefresher:
    """Tests for ``BaseTokenRefresher`` built from mock test credentials."""

    @pytest.fixture()
    def token_refresher(self, access_token: str) -> BaseTokenRefresher:
        """
        Return a BaseTokenRefresher whose mock credentials carry ``access_token``.

        ``access_token`` is supplied by each test's ``parametrize`` marker.
        """
        credentials = get_test_credentials()
        credentials.access_token = access_token
        return BaseTokenRefresher(credentials)

    @pytest.mark.parametrize("access_token", ("test", "aaaa"))
    def test_authorization_header(self, token_refresher: BaseTokenRefresher, access_token: str):
        """
        After a refresh, the Authorization header must read 'Bearer <access_token>'.
        """
        token_refresher.refresh_token()
        expected_value = f"Bearer {access_token}"
        headers = token_refresher.authorization_header
        assert headers.get("Authorization") == expected_value
class TestAirflowTokenRefresher:
    """Tests for ``AirflowTokenRefresher``, which stores the token in Airflow Variables."""

    @pytest.fixture()
    def token_refresher(self, access_token: str) -> AirflowTokenRefresher:
        """
        Return an AirflowTokenRefresher whose mock credentials carry ``access_token``.

        ``access_token`` is supplied by each test's ``parametrize`` marker.
        """
        creds = get_test_credentials()
        creds.access_token = access_token
        token_refresher = AirflowTokenRefresher(creds)
        return token_refresher

    @pytest.mark.parametrize(
        "access_token",
        [
            "test",
        ]
    )
    def test_access_token_cached(self, token_refresher: AirflowTokenRefresher, access_token: str):
        """
        Check if access token stored in Airflow Variables after refreshing it.
        """
        token_refresher.refresh_token()
        assert token_refresher.airflow_variables.get("core__auth__access_token") == access_token

    @pytest.mark.parametrize(
        "access_token",
        [
            "test",
            "aaaa",
        ]
    )
    def test_authorization_header(self, token_refresher: AirflowTokenRefresher, access_token: str):
        """
        Check if Authorization header is 'Bearer <access_token>'
        """
        token_refresher.refresh_token()
        assert token_refresher.authorization_header.get("Authorization") == f"Bearer {access_token}"

    @pytest.mark.parametrize(
        "access_token",
        [
            pytest.param("test1"),
            pytest.param("test2"),
        ]
    )
    def test_refresh_token_no_cached_variable(
        self,
        monkeypatch,
        token_refresher: AirflowTokenRefresher,
        access_token: str,
    ):
        """
        Check if token refreshes automatically if key is not stored in cache.

        This must be a method (it takes ``self``) because the ``token_refresher``
        fixture is declared inside this class and is only visible to its tests.
        ``monkeypatch`` is pytest's built-in fixture; it reverts the patch after
        the test finishes.
        """
        def mock_empty_variable(key, *args, **kwargs):
            """
            Raise KeyError as if Airflow Variables had no entry for ``key``.

            Extra positional/keyword arguments (e.g. a default value) are
            accepted and ignored so the stub works whatever the caller passes.
            """
            raise KeyError(key)

        monkeypatch.setattr(token_refresher.airflow_variables, "get", mock_empty_variable)
        # With the cache lookup failing, reading the token should fall back to
        # refreshing it from the credentials.
        expected_access_token = token_refresher.access_token
        assert expected_access_token == access_token