Commit ed1560a1 authored by Spencer Sutton


Merge branch 'master' of community.opengroup.org:osdu/platform/system/sdks/common-python-sdk into aws-update
parents ad9ff99f ae14d37d
Pipeline #67804 passed with stages in 2 minutes and 52 seconds
@@ -4,6 +4,9 @@ default:
variables:
OSDU_API_LIBS_DIR: $CI_BUILDS_DIR
CLOUD_PROVIDER: provider_test
BUILD_TAG: $CI_COMMIT_TAG
BUILD_COMMIT_SHORT_SHA: $CI_COMMIT_SHORT_SHA
BUILD_ID: $CI_PIPELINE_IID
stages:
- test
@@ -68,7 +68,7 @@ python setup.py install
```
Example import after installing:
-`from osdu_api.storage.record_client import RecordClient`
`from osdu_api.clients.storage.record_client import RecordClient`
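A hedged usage sketch of the relocated client; the commented construction below is illustrative, not the SDK's documented API:

```python
from osdu_api.clients.storage.record_client import RecordClient

# Hypothetical construction; the real arguments depend on your SDK configuration.
# record_client = RecordClient(config_manager, data_partition_id="opendes")
```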
## Installation from Package Registry
@@ -32,6 +32,7 @@ SURROGATE_KEYS_PATHS = [
]
SEARCH_ID_BATCH_SIZE = 25
SAVE_RECORDS_BATCH_SIZE = 500
DATA_SECTION = "Data"
@@ -34,6 +34,14 @@ class Context:
:rtype: Context
"""
ctx_payload = ctx.pop('Payload')
try:
data_partition_id = ctx_payload['data-partition-id']
except KeyError:
            data_partition_id = ctx['dataPartitionId']  # to support the payload interface used by some DAGs
ctx_obj = cls(app_key=ctx_payload['AppKey'],
-                      data_partition_id=ctx_payload['data-partition-id'])
data_partition_id=data_partition_id)
return ctx_obj
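A minimal standalone sketch of the two payload shapes this fallback accepts; `resolve_partition_id` is a hypothetical helper that mirrors the logic above, not SDK API:

```python
def resolve_partition_id(ctx: dict) -> str:
    # Mirrors the try/except fallback in Context: prefer the Payload key,
    # fall back to the top-level camelCase key used by some DAGs.
    payload = ctx.pop('Payload')
    try:
        return payload['data-partition-id']
    except KeyError:
        return ctx['dataPartitionId']

new_style = {'Payload': {'AppKey': 'test-app', 'data-partition-id': 'opendes'}}
dag_style = {'Payload': {'AppKey': 'test-app'}, 'dataPartitionId': 'opendes'}
assert resolve_partition_id(new_style) == 'opendes'
assert resolve_partition_id(dag_style) == 'opendes'
```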
@@ -15,6 +15,8 @@
"""Exceptions module."""
from typing import List
from osdu_api.libs.utils import create_skipped_entity_info
@@ -106,3 +108,10 @@ class ProcessRecordError(BaseEntityValidationError):
"""
    Raise when a record can't be processed.
"""
class ProcessRecordBatchError(BaseEntityValidationError):
    """
    Raise when a batch of records can't be processed.
    """
def __init__(self, entities: List[dict], reason: str):
self.skipped_entities = [
create_skipped_entity_info(entity, reason) for entity in entities
]
@@ -23,7 +23,7 @@ from uuid import uuid4
import toposort
from osdu_api.libs.linearize_manifest import ManifestEntity
from osdu_api.libs.utils import remove_trailing_colon
from osdu_api.libs.utils import is_surrogate_key, remove_trailing_colon
logger = logging.getLogger()
@@ -157,7 +157,7 @@ class EntityNode:
content = json.dumps(self.data)
for parent in self.parents:
if parent.system_srn:
if "surrogate-key" in parent.srn:
if is_surrogate_key(parent.srn):
                    # A trailing ':' marks the value as a reference when the parent srn was a surrogate key.
content = content.replace(parent.srn, f"{parent.system_srn}:")
else:
@@ -292,6 +292,20 @@ class ManifestAnalyzer:
for entity in self._invalid_entities_nodes:
entity.replace_parents_surrogate_srns()
def entity_generation_queue(self) -> Iterator[Set[EntityNode]]:
"""
        Yield sets of entities that don't depend on each other (generations).
        Parents' generations are yielded before their children's generations.
"""
entity_graph = {entity: entity.parents for entity in self.srn_node_table.values()}
logger.debug(f"Entity graph {entity_graph}.")
toposorted_entities = toposort.toposort(entity_graph)
for entity_set in toposorted_entities:
valid_entities = {entity for entity in entity_set
if entity not in self._invalid_entities_nodes and not entity.is_external_srn}
yield valid_entities
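To make the generation ordering concrete, here is a minimal sketch of how `toposort.toposort` groups a dependency graph like the one built above (the ids are made up; the real graph maps entity nodes to their parent nodes):

```python
import toposort

# Keys depend on the ids in their sets; dependency-free nodes come first.
graph = {
    "wpc": {"wpc2", "wpc3"},
    "wpc2": {"wpc4"},
    "wpc3": {"wpc4"},
    "wpc4": set(),
}
print(list(toposort.toposort(graph)))
# [{'wpc4'}, {'wpc2', 'wpc3'}, {'wpc'}]  (order within each set may vary)
```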
def add_invalid_node(self, entity: EntityNode):
"""
        Use when there are problems with ingesting or finding an entity.
@@ -30,6 +30,7 @@ from osdu_api.libs.handle_file import FileHandler
from osdu_api.libs.mixins import HeadersMixin
from osdu_api.libs.source_file_check import SourceFileChecker
from osdu_api.libs.linearize_manifest import ManifestEntity
from osdu_api.libs.utils import is_surrogate_key
RETRY_SETTINGS = {
"stop": tenacity.stop_after_attempt(RETRIES),
@@ -96,7 +97,7 @@ class ManifestProcessor(HeadersMixin):
return file_record
def _delete_surrogate_key(self, entity: dict) -> dict:
if "surrogate-key:" in entity.get("id", ""):
if is_surrogate_key(entity.get("id", "")):
del entity["id"]
return entity
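The effect, in isolation: a record that still carries a surrogate id loses it before being sent to Storage, so the service assigns a real one (the entity dict is illustrative):

```python
entity = {"id": "surrogate-key:wpc-1", "data": {}}
if "surrogate-key:" in entity.get("id", ""):  # i.e. is_surrogate_key(...)
    del entity["id"]
assert "id" not in entity
```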
@@ -19,17 +19,18 @@ R3 Process Single Manifest helper.
"""
import logging
-from typing import List, Tuple
from typing import Iterator, List, Set, Tuple
-from osdu_api.libs.constants import FIRST_STORED_RECORD_INDEX
from osdu_api.libs.constants import FIRST_STORED_RECORD_INDEX, SAVE_RECORDS_BATCH_SIZE
from osdu_api.libs.context import Context
-from osdu_api.libs.exceptions import ProcessRecordError
from osdu_api.libs.exceptions import ProcessRecordError, ProcessRecordBatchError
from osdu_api.libs.manifest_analyzer import ManifestAnalyzer, EntityNode
from osdu_api.libs.process_manifest_r3 import ManifestProcessor
from osdu_api.libs.refresh_token import TokenRefresher
from osdu_api.libs.linearize_manifest import ManifestLinearizer
from osdu_api.libs.validation.validate_referential_integrity import ManifestIntegrity
from osdu_api.libs.validation.validate_schema import SchemaValidator
from osdu_api.libs.utils import create_skipped_entity_info, is_surrogate_key, split_into_batches
logger = logging.getLogger()
@@ -44,6 +45,8 @@ class SingleManifestProcessor:
manifest_processor: ManifestProcessor,
schema_validator: SchemaValidator,
token_refresher: TokenRefresher,
batch_save_enabled: bool = False,
save_records_batch_size: int = SAVE_RECORDS_BATCH_SIZE
):
"""Init SingleManifestProcessor."""
super().__init__()
@@ -53,13 +56,17 @@
self.manifest_processor = manifest_processor
self.schema_validator = schema_validator
self.token_refresher = token_refresher
self.batch_save_enabled = batch_save_enabled
self.save_records_batch_size = save_records_batch_size
-    def _process_record(self, entity_node: EntityNode) -> str:
def _process_single_entity_node(self, manifest_analyzer: ManifestAnalyzer, entity_node: EntityNode):
"""
-        Attempt to store a single record in Storage service.
        Process a single entity node: try to save the entity's data in the Storage service,
        replacing surrogate keys of the entity and its children with the system-generated ones.
-        :param record: A record to be stored
-        :return: Record id
        :param manifest_analyzer: Object with the proper queue of entities.
        :param entity_node: Entity node to be processed.
        :return: Saved record id.
"""
try:
logger.debug(f"Process entity {entity_node}")
@@ -68,11 +75,44 @@
[entity_node.entity_info]
)[FIRST_STORED_RECORD_INDEX]
entity_node.system_srn = record_id
-        except Exception as e:
-            raise ProcessRecordError(entity_node.entity_info.entity_data, f"{e}"[:128])
except Exception as error:
logger.warning(f"Can't process entity {entity_node}")
manifest_analyzer.add_invalid_node(entity_node)
raise ProcessRecordError(entity_node.entity_info.entity_data, f"{error}"[:128])
return record_id
-    def _process_records(self, manifest_analyzer: ManifestAnalyzer) -> Tuple[List[str], List[dict]]:
def _process_entity_nodes_batch(
self,
manifest_analyzer: ManifestAnalyzer,
entity_node_batch: List[EntityNode]
) -> Tuple[List[str], List[dict]]:
"""
        Try to process a batch of EntityNodes by saving their data in the Storage Service.
        In the current implementation of the Storage Service,
        the whole batch isn't saved if one or more entities in it are invalid.
        :param manifest_analyzer: Object with the proper queue of entities.
        :param entity_node_batch: Batch of entity nodes to be saved.
        :return: List of saved record ids.
"""
record_ids = []
try:
manifest_entities = [entity_node.entity_info for entity_node in entity_node_batch]
record_ids.extend(self.manifest_processor.process_manifest_records(manifest_entities))
except Exception as e:
# TODO: Fix skipping saving the whole batch in Storage Service if some records in this batch are invalid.
logger.warning(f"Can't process batch {entity_node_batch}. {str(e)[:128]}")
for entity_node in entity_node_batch:
manifest_analyzer.add_invalid_node(entity_node)
raise ProcessRecordBatchError([node.data for node in entity_node_batch], f"{e}"[:128])
return record_ids
def _process_records_by_one(self, manifest_analyzer: ManifestAnalyzer) -> Tuple[List[str], List[dict]]:
"""
        Process each entity from the entity queue, which is built according to
        child-parent relationships between entities.
@@ -83,15 +123,86 @@
"""
record_ids = []
skipped_ids = []
-        for entity in manifest_analyzer.entity_queue():
        for entity_node in manifest_analyzer.entity_queue():
            try:
-                record_ids.append(self._process_record(entity))
                record_ids.append(self._process_single_entity_node(manifest_analyzer, entity_node))
            except ProcessRecordError as error:
-                logger.warning(f"Can't process entity {entity}")
-                manifest_analyzer.add_invalid_node(entity)
skipped_ids.append(error.skipped_entity)
return record_ids, skipped_ids
def _split_ids_by_type(
self,
entity_node_batch: List[EntityNode]
) -> Tuple[List[EntityNode], List[EntityNode]]:
"""
        Split an entity node batch into two lists: nodes with surrogate keys and nodes with real ids.
        :param entity_node_batch: Batch of entity nodes.
        :return: Two lists: non-surrogate-key entity nodes and surrogate-key entity nodes.
        """
        surrogate_key_id_nodes = []
        not_surrogate_key_id_nodes = []
        for entity_node in entity_node_batch:
            entity_node.replace_parents_surrogate_srns()
            if is_surrogate_key(entity_node.data.get("id", "")):
                surrogate_key_id_nodes.append(entity_node)
            else:
                not_surrogate_key_id_nodes.append(entity_node)
        return not_surrogate_key_id_nodes, surrogate_key_id_nodes
def _save_entities_generation(
self,
manifest_analyzer: ManifestAnalyzer,
entity_nodes_generation: Set[EntityNode]
    ) -> Tuple[List[str], List[dict]]:
        """
        Save a set of mutually independent entities in the Storage Service in chunks.
        :param manifest_analyzer: Object with the proper queue of entities.
        :param entity_nodes_generation: Set of mutually independent entity nodes.
        :return: List of saved ids and skipped ones.
"""
record_ids = []
skipped_ids = []
for entity_node_batch in split_into_batches(entity_nodes_generation, self.save_records_batch_size):
            # Surrogate-key entities and real-id entities must be treated in different ways.
not_surrogate_key_nodes, surrogate_key_nodes = self._split_ids_by_type(entity_node_batch)
if not_surrogate_key_nodes:
try:
record_ids.extend(
self._process_entity_nodes_batch(manifest_analyzer, not_surrogate_key_nodes)
)
except ProcessRecordBatchError as error:
skipped_ids.extend(error.skipped_entities)
for entity_node in surrogate_key_nodes:
try:
record_ids.append(self._process_single_entity_node(manifest_analyzer, entity_node))
except ProcessRecordError as error:
skipped_ids.append(error.skipped_entity)
return record_ids, skipped_ids
def _process_records_by_batches(self, manifest_analyzer: ManifestAnalyzer) -> Tuple[List[str], List[dict]]:
"""
        Save batches of entities in the Storage Service.
        :param manifest_analyzer: Object with the proper queue of entities.
        :return: List of saved ids and skipped ones.
"""
record_ids = []
skipped_ids = []
for entity_nodes_generation in manifest_analyzer.entity_generation_queue():
logger.info(f"Generation: {entity_nodes_generation}")
generation_record_ids, generation_skipped_ids = self._save_entities_generation(manifest_analyzer, entity_nodes_generation)
record_ids.extend(generation_record_ids)
skipped_ids.extend(generation_skipped_ids)
return record_ids, skipped_ids
def process_manifest(self, manifest: dict, with_validation: bool) -> Tuple[
List[str], List[dict]]:
"""Execute manifest validation then process it.
@@ -130,8 +241,16 @@
{entity["id"] for entity in skipped_ids if entity.get("id")}
)
-        record_ids, not_valid_ids = self._process_records(manifest_analyzer)
if self.batch_save_enabled:
record_ids, not_valid_ids = self._process_records_by_batches(manifest_analyzer)
else:
record_ids, not_valid_ids = self._process_records_by_one(manifest_analyzer)
skipped_ids.extend(not_valid_ids)
skipped_ids.extend(
[create_skipped_entity_info(node.data, f"Missing parents {node.invalid_parents}")
for node in manifest_analyzer.invalid_entity_nodes if node.invalid_parents]
)
logger.info(f"Processed ids {record_ids}")
@@ -15,11 +15,15 @@
"""Util functions to work with OSDU Manifests."""
-from typing import Any
from itertools import islice
from typing import Any, Generator, Iterable, List, TypeVar
import dataclasses
BatchElement = TypeVar("BatchElement")
@dataclasses.dataclass
class EntityId:
id: str
@@ -74,3 +78,45 @@ def create_skipped_entity_info(entity: Any, reason: str) -> dict:
"reason": reason[:128]
}
return skipped_entity_info
def split_into_batches(
element_sequence: Iterable[BatchElement],
batch_size: int
) -> Generator[List[BatchElement], None, None]:
"""
    Split a sequence of elements into batches of the same size.
    :param element_sequence: Iterable of elements to be split.
    :param batch_size: Maximum number of elements in a batch.
    :return: Generator of batches.
    """
    if not isinstance(element_sequence, Iterable):
        raise TypeError(
            f"Element sequence '{element_sequence}' is '{type(element_sequence)}'. "
            "It must be an iterable, e.g. 'list' or 'tuple'."
)
element_sequence = iter(element_sequence)
while True:
batch = list(islice(element_sequence, batch_size))
if not batch:
return
yield batch
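A quick usage check for the batching helper; any iterable works and the final batch may be shorter:

```python
batches = list(split_into_batches(range(7), batch_size=3))
assert batches == [[0, 1, 2], [3, 4, 5], [6]]
```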
def is_surrogate_key(entity_id: str) -> bool:
    """
    Check if the entity's id is a surrogate key.
    :param entity_id: Entity ID.
    :return: True if the id is a surrogate key, False otherwise.
    """
    return "surrogate-key:" in entity_id
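And a sanity check for the surrogate-key predicate (the real-style id is illustrative):

```python
assert is_surrogate_key("surrogate-key:wpc-1")
assert not is_surrogate_key("opendes:work-product--WorkProduct:some-id")
```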
@@ -22,8 +22,8 @@ from osdu_api.libs.constants import DATA_SECTION, DATASETS_SECTION, WORK_PRODUCT
from osdu_api.libs.exceptions import EmptyManifestError, ValidationIntegrityError
from osdu_api.libs.search_record_ids import ExtendedSearchId
from osdu_api.libs.validation.validate_file_source import FileSourceValidator
-from osdu_api.libs.utils import create_skipped_entity_info, split_id, remove_trailing_colon, \
-    EntityId
from osdu_api.libs.utils import create_skipped_entity_info, split_id, split_into_batches, \
remove_trailing_colon, EntityId
from osdu_api.libs.linearize_manifest import ManifestLinearizer
from osdu_api.libs.manifest_analyzer import ManifestAnalyzer
@@ -85,7 +85,7 @@ class ManifestIntegrity:
        :return: Set of references not found via Search.
"""
missing_ids = set()
-        for ids_batch in self._external_ids_batch(external_references):
for ids_batch in split_into_batches(external_references, self.search_id_batch_size):
# Search can't work with ids with versions. So get only ids without versions.
external_references_without_version = [e.id for e in ids_batch]
@@ -98,25 +98,6 @@
return {missing_id.srn for missing_id in missing_ids}
-    def _external_ids_batch(
-        self,
-        external_ids: List[EntityId]
-    ) -> Generator[Set[EntityId], None, tuple]:
-        """
-        Split external ids into batches of the same size
-        :param external_ids:
-        :return:
-        """
-        if external_ids:
-            chunk_number = len(external_ids) // self.search_id_batch_size + 1
-            for n in range(chunk_number):
-                offset = n * self.search_id_batch_size
-                ids_slice = external_ids[offset:offset + self.search_id_batch_size]
-                yield set(ids_slice)
-        else:
-            return ()
def _ensure_wpc_artefacts_integrity(self, wpc: dict):
artefacts = wpc["data"].get("Artefacts")
if not artefacts:
@@ -21,3 +21,4 @@ class HttpMethod(enum.Enum):
POST = enum.auto()
PUT = enum.auto()
DELETE = enum.auto()
PATCH = enum.auto()
@@ -22,6 +22,9 @@ import msal
import os
from azure.keyvault import secrets
from azure import identity
import requests
import json
logger = logging.getLogger(__name__)
RETRIES = 3
@@ -37,6 +40,7 @@ class AzureCredentials(BaseCredentials):
self._client_secret = None
self._tenant_id = None
self._resource_id = None
        self._azure_paas_podidentity_isEnabled = os.getenv("AIRFLOW_VAR_AZURE_ENABLE_MSI")
def _populate_ad_credentials(self) -> None:
uri = os.getenv("AIRFLOW_VAR_KEYVAULT_URI")
@@ -48,34 +52,48 @@
self._resource_id = client.get_secret("aad-client-id").value
def _generate_token(self) -> str:
-        if self._client_id is None:
-            self._populate_ad_credentials()
-        if self._tenant_id is None:
-            logger.error('TenantId is not set properly')
-            raise ValueError("TenantId is not set properly")
-        if self._resource_id is None:
-            logger.error('ResourceId is not set properly')
-            raise ValueError("ResourceId is not set properly")
-        if self._client_id is None:
-            logger.error('Please pass client Id to generate token')
-            raise ValueError("Please pass client Id to generate token")
-        if self._client_secret is None:
-            logger.error('Please pass client secret to generate token')
-            raise ValueError("Please pass client secret to generate token")
-        try:
-            authority_host_uri = 'https://login.microsoftonline.com'
-            authority_uri = authority_host_uri + '/' + self._tenant_id
-            scopes = [self._resource_id + '/.default']
-            app = msal.ConfidentialClientApplication(client_id = self._client_id,
-                                                     authority = authority_uri,
-                                                     client_credential = self._client_secret)
-            result = app.acquire_token_for_client(scopes=scopes)
-            return result.get('access_token')
-        except Exception as e:
-            logger.error(e)
-            raise e
if self._azure_paas_podidentity_isEnabled == "true":
try:
print("MSI Token generation")
headers = {
'Metadata': 'true'
}
url = 'http://169.254.169.254/metadata/identity/oauth2/token?api-version=2018-02-01&resource=https%3A%2F%2Fmanagement.azure.com%2F'
response = requests.request("GET", url, headers=headers)
data_msi = json.loads(response.text)
token = data_msi["access_token"]
return token
except Exception as e:
logger.error(e)
raise e
else:
if self._client_id is None:
self._populate_ad_credentials()
if self._tenant_id is None:
logger.error('TenantId is not set properly')
raise ValueError("TenantId is not set properly")
if self._resource_id is None:
logger.error('ResourceId is not set properly')
raise ValueError("ResourceId is not set properly")
if self._client_id is None:
logger.error('Please pass client Id to generate token')
raise ValueError("Please pass client Id to generate token")
if self._client_secret is None:
logger.error('Please pass client secret to generate token')
raise ValueError("Please pass client secret to generate token")
try:
authority_host_uri = 'https://login.microsoftonline.com'
authority_uri = authority_host_uri + '/' + self._tenant_id
scopes = [self._resource_id + '/.default']
app = msal.ConfidentialClientApplication(client_id = self._client_id,
authority = authority_uri,
client_credential = self._client_secret)
result = app.acquire_token_for_client(scopes=scopes)
return result.get('access_token')
except Exception as e:
logger.error(e)
raise e
@retry(stop=stop_after_attempt(RETRIES))
def refresh_token(self) -> str:
@@ -20,6 +20,7 @@ DATA_PATH_PREFIX = f"{os.path.dirname(__file__)}/data"
MANIFEST_REFERENCE_PATTERNS_WHITELIST = f"{DATA_PATH_PREFIX}/reference_patterns_whitelist.txt"
MANIFEST_GENERIC_SCHEMA_PATH = f"{DATA_PATH_PREFIX}/manifests/schema_Manifest.1.0.0.json"
MANIFEST_BATCH_SAVE_PATH = f"{DATA_PATH_PREFIX}/manifests/batch_save_Manifest.json"
MANIFEST_NEW_GENERIC_SCHEMA_PATH = f"{DATA_PATH_PREFIX}/manifests/new_schema_Manifest.1.0.0.json"
MANIFEST_GENERIC_PATH = f"{DATA_PATH_PREFIX}/manifests/Manifest.1.0.0.json"
@@ -296,3 +296,41 @@ class TestManifestAnalyzer(object):
)
wpc_parents = {p.srn for p in manifest_analyzer.srn_node_table[tested_wpc["id"]].parents}
assert wpc_parents == set(expected_result)
@pytest.mark.parametrize(
"manifest,expected_generations",
[
pytest.param(
[
{
"id": "surrogate-key:wpc",
"ref": "surrogate-key:wpc2",
"ref2": "surrogate-key:wpc3",
},
{
"id": "surrogate-key:wpc2",
"ref": "surrogate-key:wpc4",
},
{
"id": "surrogate-key:wpc3",
"ref": "surrogate-key:wpc4"
},
{
"id": "surrogate-key:wpc4",
}
],
[{"surrogate-key:wpc4"}, {"surrogate-key:wpc3", "surrogate-key:wpc2"}, {"surrogate-key:wpc"}],
id="Surrogate key"
)
]
)
def test_generation_queue(self, manifest, expected_generations):
data = [ManifestEntity(entity_data=e, manifest_path="") for e in manifest]
manifest_analyzer = ManifestAnalyzer(
data
)
result = []
for generation in manifest_analyzer.entity_generation_queue():
result.append({e.data.get("id") for e in generation})
assert result == expected_generations
\ No newline at end of file