Commit 82ad8ba8 authored by Yannick's avatar Yannick
Browse files

back to synchronous partition id

parent c6b115a7
Pipeline #45242 passed with stage
in 34 seconds
import logging import logging
from azure.identity.aio import DefaultAzureCredential from azure.identity import DefaultAzureCredential
from azure.keyvault.secrets.aio import SecretClient from azure.keyvault.secrets import SecretClient
from osdu_az import conf from osdu_az import conf
...@@ -20,33 +20,25 @@ class AzureIdentity: ...@@ -20,33 +20,25 @@ class AzureIdentity:
return cls.default_credential return cls.default_credential
@classmethod @classmethod
async def close_credentials(cls): def get_access_token(cls):
""" This cause to gracefully dispose credentials if any. Next calls will then initialize a new one """
_LOGGER.info(f"Closing cached Credentials")
credentials_to_close, cls.default_credential = cls.default_credential, None # swap
if credentials_to_close is not None:
await credentials_to_close.close()
@classmethod
async def get_access_token(cls):
credential = cls.get_default_credential() credential = cls.get_default_credential()
scope = await cls._get_scope() scope = cls._get_scope()
access_token = await credential.get_token(scope) access_token = credential.get_token(scope)
return access_token.token return access_token.token
@classmethod @classmethod
async def _get_scope(cls): def _get_scope(cls):
if not cls.default_scope: if not cls.default_scope:
cls.default_scope = await cls._get_resource_id() cls.default_scope = cls._get_resource_id()
return cls.default_scope return cls.default_scope
@classmethod @classmethod
async def _get_resource_id(cls) -> str: def _get_resource_id(cls) -> str:
return await cls.get_secret('aad-client-id') return cls.get_secret('aad-client-id')
@classmethod @classmethod
async def get_secret(cls, name) -> str: def get_secret(cls, name) -> str:
if cls._secret_client is None: if cls._secret_client is None:
cls._secret_client = SecretClient(conf.keyvault_url, cls.get_default_credential()) cls._secret_client = SecretClient(conf.keyvault_url, cls.get_default_credential())
secret = await cls._secret_client.get_secret(name) secret = cls._secret_client.get_secret(name)
return secret.value return secret.value
...@@ -20,7 +20,7 @@ class PartitionClient: ...@@ -20,7 +20,7 @@ class PartitionClient:
@staticmethod @staticmethod
async def get_partition(data_partition_id: str) -> Optional[PartitionInfo]: async def get_partition(data_partition_id: str) -> Optional[PartitionInfo]:
access_token = await AzureIdentity.get_access_token() access_token = AzureIdentity.get_access_token()
headers = { headers = {
'authorization': f'Bearer {access_token}' 'authorization': f'Bearer {access_token}'
......
...@@ -12,7 +12,7 @@ class PartitionInfo: ...@@ -12,7 +12,7 @@ class PartitionInfo:
def __init__(self, partition_properties: dict = None): def __init__(self, partition_properties: dict = None):
self._partition_properties = partition_properties self._partition_properties = partition_properties
async def get_value(self, property_name: str) -> Optional[str]: def get_value(self, property_name: str) -> Optional[str]:
partition_property = self._partition_properties.get(property_name) partition_property = self._partition_properties.get(property_name)
if not partition_property: if not partition_property:
return None return None
...@@ -21,12 +21,12 @@ class PartitionInfo: ...@@ -21,12 +21,12 @@ class PartitionInfo:
return partition_property['value'] return partition_property['value']
if 'secret' not in partition_property: if 'secret' not in partition_property:
partition_property['secret'] = await self._get_secret(partition_property['value']) partition_property['secret'] = self._get_secret(partition_property['value'])
return partition_property['secret'] return partition_property['secret']
async def _get_secret(self, key: str) -> str: def _get_secret(self, key: str) -> str:
ts = datetime.utcnow() ts = datetime.utcnow()
secret = await AzureIdentity.get_secret(key) secret = AzureIdentity.get_secret(key)
_LOGGER.info(f'PartitionInfo get secret took {(datetime.utcnow() - ts).total_seconds()} ms') _LOGGER.info(f'PartitionInfo get secret took {(datetime.utcnow() - ts).total_seconds()} ms')
return secret return secret
......
...@@ -45,7 +45,7 @@ class PartitionService: ...@@ -45,7 +45,7 @@ class PartitionService:
async def get_storage_account_name(data_partition_id: str): async def get_storage_account_name(data_partition_id: str):
partition_info = await PartitionService.get_partition(data_partition_id) partition_info = await PartitionService.get_partition(data_partition_id)
if partition_info: if partition_info:
return await partition_info.get_value(STORAGE_ACCOUNT_NAME) return partition_info.get_value(STORAGE_ACCOUNT_NAME)
@staticmethod @staticmethod
def _partition_client(): def _partition_client():
......
...@@ -2,7 +2,7 @@ from cachetools import TTLCache ...@@ -2,7 +2,7 @@ from cachetools import TTLCache
class PartitionsCache: class PartitionsCache:
partitions_cache = TTLCache(maxsize=100, ttl=300) partitions_cache = TTLCache(maxsize=100, ttl=3600)
@staticmethod @staticmethod
def get(data_partition_id: str): def get(data_partition_id: str):
......
...@@ -60,8 +60,8 @@ class AzureAioBlobStorage(BlobStorageBase): ...@@ -60,8 +60,8 @@ class AzureAioBlobStorage(BlobStorageBase):
async def _get_blob_service_client(self, tenant): async def _get_blob_service_client(self, tenant):
partition_info = await PartitionService.get_partition(tenant.data_partition_id) partition_info = await PartitionService.get_partition(tenant.data_partition_id)
storage_account = await partition_info.get_value(STORAGE_ACCOUNT_NAME) storage_account = partition_info.get_value(STORAGE_ACCOUNT_NAME)
cred = await partition_info.get_value(STORAGE_ACCOUNT_KEY) # use account key cred = partition_info.get_value(STORAGE_ACCOUNT_KEY) # use account key
if cred is None: if cred is None:
cred = self._get_credentials() cred = self._get_credentials()
account_url = self._build_url(storage_account) account_url = self._build_url(storage_account)
......
import json import json
from azure.keyvault.secrets import KeyVaultSecret from mock import patch
from mock import MagicMock
import pytest
from osdu_az.partition.partition_info import PartitionInfo from osdu_az.partition.partition_info import PartitionInfo
partition_service_response = \ partition_service_response = \
...@@ -43,52 +41,40 @@ partition_service_response = \ ...@@ -43,52 +41,40 @@ partition_service_response = \
""" """
@pytest.mark.skip(reason='tmp')
def test_get_value(): def test_get_value():
storage_account_name = 'mocked_storage_account_name_in_key_vault' storage_account_name = 'mocked_storage_account_name_in_key_vault'
kv_mock = MagicMock()
kv_mock.get_secret.return_value = KeyVaultSecret('opendes-storage', storage_account_name)
properties = json.loads(partition_service_response) properties = json.loads(partition_service_response)
partition_info = PartitionInfo(partition_properties=properties) partition_info = PartitionInfo(partition_properties=properties)
partition_info._secret_client = kv_mock
assert (partition_info.get_value('storage-account-name') == storage_account_name) with patch.object(PartitionInfo, '_get_secret', return_value=storage_account_name) as mock:
kv_mock.get_secret.assert_called_with('opendes-storage') assert (partition_info.get_value('storage-account-name') == storage_account_name)
mock.assert_called_with('opendes-storage')
@pytest.mark.skip(reason='tmp')
def test_get_value_invalid_property_name(): def test_get_value_invalid_property_name():
properties = json.loads(partition_service_response) properties = json.loads(partition_service_response)
partition_info = PartitionInfo(partition_properties=properties) partition_info = PartitionInfo(partition_properties=properties)
assert (partition_info.get_value('invalid-partition-service-property') is None) assert (partition_info.get_value('invalid-partition-service-property') is None)
@pytest.mark.skip(reason='tmp')
def test_get_value_non_sensitive(): def test_get_value_non_sensitive():
kv_mock = MagicMock()
properties = json.loads(partition_service_response) properties = json.loads(partition_service_response)
partition_info = PartitionInfo(partition_properties=properties) partition_info = PartitionInfo(partition_properties=properties)
partition_info._secret_client = kv_mock
assert (partition_info.get_value('compliance-ruleset') == 'shared') with patch.object(PartitionInfo, '_get_secret') as mock:
kv_mock.get_secret.assert_not_called() assert (partition_info.get_value('compliance-ruleset') == 'shared')
mock.assert_not_called()
@pytest.mark.skip(reason='tmp')
def test_not_trigger_key_vault_fetch_twice(): def test_not_trigger_key_vault_fetch_twice():
storage_account_name = 'mocked_storage_account_name_in_key_vault' storage_account_name = 'mocked_storage_account_name_in_key_vault'
kv_mock = MagicMock()
kv_mock.get_secret.return_value = KeyVaultSecret('opendes-storage', storage_account_name)
properties = json.loads(partition_service_response) properties = json.loads(partition_service_response)
partition_info = PartitionInfo(partition_properties=properties) partition_info = PartitionInfo(partition_properties=properties)
partition_info._secret_client = kv_mock
partition_info.get_value('storage-account-name') with patch.object(PartitionInfo, '_get_secret', return_value=storage_account_name) as mock:
partition_info.get_value('storage-account-name') partition_info.get_value('storage-account-name')
partition_info.get_value('storage-account-name')
mock.assert_called_once_with('opendes-storage')
kv_mock.get_secret.assert_called_once()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment