diff --git a/src/dags/providers/azure/azure_credentials.py b/src/dags/providers/azure/azure_credentials.py index 6c52c4556dfb004281d1180509a826b61f9f27c1..9ced58f99b6e9a67b68a9fc505c30c0c4a6f5b9a 100644 --- a/src/dags/providers/azure/azure_credentials.py +++ b/src/dags/providers/azure/azure_credentials.py @@ -20,8 +20,8 @@ from providers.types import BaseCredentials from tenacity import retry, stop_after_attempt import msal import os -from azure.keyvault.secrets import SecretClient -from azure.identity import DefaultAzureCredential +from azure.keyvault import secrets +from azure import identity logger = logging.getLogger(__name__) RETRIES = 3 @@ -40,8 +40,8 @@ class AzureCredentials(BaseCredentials): def _populate_ad_credentials(self) -> None: uri = os.getenv("AIRFLOW_VAR_KEYVAULT_URI") - credential = DefaultAzureCredential() - client = SecretClient(vault_url=uri, credential=credential) + credential = identity.DefaultAzureCredential() + client = secrets.SecretClient(vault_url=uri, credential=credential) self._client_id = client.get_secret("app-dev-sp-username").value self._client_secret = client.get_secret('app-dev-sp-password').value self._tenant_id = client.get_secret('app-dev-sp-tenant-id').value diff --git a/tests/providers-unit-tests/test_azure_credentials.py b/tests/providers-unit-tests/test_azure_credentials.py new file mode 100644 index 0000000000000000000000000000000000000000..9a017d2f4ed999512acdb1b90b30c5fc9190c379 --- /dev/null +++ b/tests/providers-unit-tests/test_azure_credentials.py @@ -0,0 +1,110 @@ +import os +import sys + +sys.path.append(f"{os.getenv('AIRFLOW_SRC_DIR')}/plugins") +sys.path.append(f"{os.getenv('AIRFLOW_SRC_DIR')}/dags") + +import pytest +from providers.azure.azure_credentials import AzureCredentials +import msal +from azure import identity +from azure.keyvault import secrets + +CLIENT_ID = "someClientId" +CLIENT_SECRET = "someClientSecret" +TENANT_ID = "someTenantId" +RESOURCE_ID = "someResourceId" +TOKEN = "someToken" +AUTHORITY_URI = 'https://login.microsoftonline.com/' + TENANT_ID +KEY_VAULT_URL = 'https://keyvault.com' +SCOPES = ["someResourceId/.default"] + + +class MockConfidentialClientApplication: + def __init__(self, client_id: str, authority: str, client_credential: str): + assert client_id == CLIENT_ID + assert client_credential == CLIENT_SECRET + assert authority == AUTHORITY_URI + + def acquire_token_for_client(self, scopes: list): + assert scopes == SCOPES + return {"access_token": TOKEN} + + +class MockDefaultAzureCredentials: + def __init__(self): + pass + + +class MockSecret: + def __init__(self, value): + self._value = value + + @property + def value(self) -> str: + return self._value + + +class MockSecretClient: + def __init__(self, vault_url: str, credential): + assert vault_url == KEY_VAULT_URL + assert isinstance(credential, MockDefaultAzureCredentials) + + def get_secret(self, key: str): + if key == "app-dev-sp-username": + return MockSecret(CLIENT_ID) + elif key == "app-dev-sp-password": + return MockSecret(CLIENT_SECRET) + elif key == "app-dev-sp-tenant-id": + return MockSecret(TENANT_ID) + elif key == "aad-client-id": + return MockSecret(RESOURCE_ID) + else: + raise ValueError("Invalid Key") + + +class TestAzureCredentials: + @pytest.fixture() + def azure_credentials(self, monkeypatch, mock_credentials: bool) -> AzureCredentials: + azure_credentials = AzureCredentials() + if mock_credentials: + monkeypatch.setattr(azure_credentials, "_client_id", CLIENT_ID) + monkeypatch.setattr(azure_credentials, "_client_secret", CLIENT_SECRET) + monkeypatch.setattr(azure_credentials, "_tenant_id", TENANT_ID) + monkeypatch.setattr(azure_credentials, "_resource_id", RESOURCE_ID) + else: + monkeypatch.setenv("AIRFLOW_VAR_KEYVAULT_URI", KEY_VAULT_URL) + monkeypatch.setattr(identity, + "DefaultAzureCredential", + MockDefaultAzureCredentials) + monkeypatch.setattr(secrets, + "SecretClient", + MockSecretClient) + + monkeypatch.setattr(msal, + "ConfidentialClientApplication", + MockConfidentialClientApplication) + return azure_credentials + + @pytest.mark.parametrize("mock_credentials", [pytest.param(True)]) + def test_refresh_token_with_credential_information_available( + self, + azure_credentials: AzureCredentials, + mock_credentials: bool): + + """ + Checks if token is fetched properly if credential information is already available. + """ + assert azure_credentials.refresh_token() == TOKEN + + @pytest.mark.parametrize("mock_credentials", [pytest.param(False)]) + def test_refresh_token_with_credential_information_missing( + self, + azure_credentials: AzureCredentials, + mock_credentials: bool): + """ + Checks if credentials are fetched properly and then token is generated with fetched + credentials. + """ + assert azure_credentials.refresh_token() == TOKEN + diff --git a/tests/providers-unit-tests/test_credentials.py b/tests/providers-unit-tests/test_credentials.py index 2ffef7bc7cc45ab1d45575bcd7f921c997ee3114..aec934779835f010efa7a8b4d7e6bc486e0298c6 100644 --- a/tests/providers-unit-tests/test_credentials.py +++ b/tests/providers-unit-tests/test_credentials.py @@ -24,6 +24,7 @@ import providers.credentials from providers.credentials import get_credentials from providers.types import BaseCredentials from providers.gcp.gcp_credentials import GCPCredentials +from providers.azure.azure_credentials import AzureCredentials DATA_PATH_PREFIX = f"{os.path.dirname(__file__)}/data" @@ -42,6 +43,7 @@ class TestGetCredentials: @pytest.mark.parametrize("provider, instance_type", [ pytest.param("gcp", GCPCredentials), + pytest.param("azure", AzureCredentials) ]) def test_get_credentials_inferred_env(self, monkeypatch, mock_os_environ, provider: str, instance_type: BaseCredentials): @@ -51,6 +53,7 @@ class TestGetCredentials: @pytest.mark.parametrize("provider, instance_type", [ pytest.param("gcp", GCPCredentials), + pytest.param("azure", AzureCredentials) ]) def test_get_credentials_explicit_env(self, monkeypatch, mock_os_environ, provider: str, instance_type: BaseCredentials):