Skip to content
Snippets Groups Projects
test_refresh_token.py 3.53 KiB
Newer Older
  • Learn to ignore specific revisions
    #  Copyright 2020 Google LLC
    #  Copyright 2020 EPAM Systems
    #
    #  Licensed under the Apache License, Version 2.0 (the "License");
    #  you may not use this file except in compliance with the License.
    #  You may obtain a copy of the License at
    #
    #      http://www.apache.org/licenses/LICENSE-2.0
    #
    #  Unless required by applicable law or agreed to in writing, software
    #  distributed under the License is distributed on an "AS IS" BASIS,
    #  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    #  See the License for the specific language governing permissions and
    #  limitations under the License.
    
    
    import json
    
    import os
    import sys
    
    import pytest
    
    from google.oauth2 import service_account
    
    
    # Put the Airflow DAGs source directory on the import path, presumably so the
    # `libs.refresh_token` import below resolves.
    # NOTE(review): if AIRFLOW_SRC_DIR is unset, os.getenv returns None and the
    # appended entry becomes the literal "None/dags" — confirm the env var is set in CI.
    sys.path.append(f"{os.getenv('AIRFLOW_SRC_DIR')}/dags")
    
    
    from libs.refresh_token import AirflowTokenRefresher, BaseTokenRefresher
    
    from mock_providers import get_test_credentials
    
    class TestBaseTokenRefresher:
        """Unit tests for ``BaseTokenRefresher``."""

        @pytest.fixture()
        def token_refresher(self, access_token: str) -> BaseTokenRefresher:
            """
            Build a BaseTokenRefresher from test credentials carrying ``access_token``.

            ``access_token`` is injected by the test's ``parametrize`` marker.
            """
            credentials = get_test_credentials()
            credentials.access_token = access_token
            return BaseTokenRefresher(credentials)

        @pytest.mark.parametrize("access_token", ["test", "aaaa"])
        def test_authorization_header(self, token_refresher: BaseTokenRefresher, access_token: str):
            """
            Check if Authorization header is 'Bearer <access_token>'
            """
            token_refresher.refresh_token()
            header_value = token_refresher.authorization_header.get("Authorization")
            assert header_value == f"Bearer {access_token}"
    
    
    
    class TestAirflowTokenRefresher:
        """Unit tests for ``AirflowTokenRefresher``."""

        @pytest.fixture()
        def token_refresher(self, access_token: str) -> AirflowTokenRefresher:
            """
            Build an AirflowTokenRefresher from test credentials carrying ``access_token``.

            ``access_token`` is injected by each test's ``parametrize`` marker.
            """
            creds = get_test_credentials()
            creds.access_token = access_token
            token_refresher = AirflowTokenRefresher(creds)
            return token_refresher

        @pytest.mark.parametrize(
            "access_token",
            [
                "test"
            ]
        )
        def test_access_token_cached(self, token_refresher: AirflowTokenRefresher, access_token: str):
            """
            Check if access token stored in Airflow Variables after refreshing it.
            """
            token_refresher.refresh_token()
            assert token_refresher.airflow_variables.get("core__auth__access_token") == access_token

        @pytest.mark.parametrize(
            "access_token",
            [
                "test",
                "aaaa"
            ]
        )
        def test_authorization_header(self, token_refresher: AirflowTokenRefresher, access_token: str):
            """
            Check if Authorization header is 'Bearer <access_token>'
            """
            token_refresher.refresh_token()
            assert token_refresher.authorization_header.get("Authorization") == f"Bearer {access_token}"

        @pytest.mark.parametrize(
            "access_token",
            [
                pytest.param("test1"),
                pytest.param("test2"),
            ]
        )
        def test_refresh_token_no_cached_variable(
            self,
            monkeypatch,
            token_refresher: AirflowTokenRefresher,
            access_token: str,
        ):
            """
            Check if token refreshes automatically if key is not stored in cache.

            ``access_token`` must be a parameter here: the assertion below compares
            against it, and the fixture alone does not put it in this scope.
            """

            def mock_empty_variable(key, *args, **kwargs):
                """
                Raise KeyError as if Airflow Variables doesn't have the 'access_token' variable.

                Extra positional/keyword arguments are accepted so the simulated
                cache miss is a KeyError regardless of how ``get`` is called.
                """
                raise KeyError(key)

            # For the duration of this test, every `airflow_variables.get` call
            # raises KeyError, simulating a missing cached token.
            monkeypatch.setattr(token_refresher.airflow_variables, "get", mock_empty_variable)

            actual_access_token = token_refresher.access_token
            assert actual_access_token == access_token