Commit 28ff1705 authored by Siarhei Khaletski (EPAM)

Merge branch 'feature/GONRG-1652_Replace_surrogate_key' into 'integration-master'

GONRG-1652 "Feature/ replace surrogate key"

See merge request go3-nrg/platform/data-flow/ingestion/ingestion-dags!43
parents 95e35940 8d646d71
Showing with 719 additions and 48 deletions
......@@ -30,6 +30,7 @@
**/.installed.cfg
**/*.egg
**/MANIFEST
**/venv
# ignore coverage.py
htmlcov/*
......
......@@ -18,3 +18,5 @@
RETRIES = 3
TIMEOUT = 1
WAIT = 10
FIRST_STORED_RECORD_INDEX = 0
......@@ -54,7 +54,6 @@ class FileSourceError(Exception):
class UploadFileError(Exception):
"""Raise when there is an error while uploading a file into OSDU."""
class TokenRefresherNotPresentError(Exception):
"""Raise when token refresher is not present in "refresh_token' decorator."""
pass
......@@ -64,6 +63,12 @@ class NoParentEntitySystemSRNError(Exception):
"""Raise when parent entity doesn't have system-generated SRN."""
pass
class NoParentEntitySystemSRNError(Exception):
"""
Raise when parent entity doesn't have system-generated SRN.
"""
pass
class InvalidFileRecordData(Exception):
"""Raise when file data does not contain mandatory fields."""
# Copyright 2020 Google LLC
# Copyright 2020 EPAM Systems
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import logging
import re
from collections import deque
from typing import Set, Iterator, Iterable
from uuid import uuid4
import tenacity
import requests
import toposort
from libs.constants import RETRIES, TIMEOUT
from libs.context import Context
from libs.mixins import HeadersMixin
from libs.refresh_token import TokenRefresher, refresh_token
from libs.traverse_manifest import ManifestEntity
logger = logging.getLogger()
class EntityNode(object):
"""
This class represents an entity and its links to parent and child entities.
"""
__slots__ = ["srn", "system_srn", "entity_info", "children", "parents", "unprocessed"]
SRN_REGEX = re.compile(
r"(?<=\")surrogate-key:[\w\-\.\d]+(?=\")|(?<=\")[\w\-\.]+:[\w\-\.]+--[\w\-\.]+:[\w\-\.\d]+(?=\")")
def __init__(self, srn, entity_info: ManifestEntity):
self.srn = srn
self.entity_info = entity_info
self.system_srn = None
self.children = set()
self.parents = set()
self.unprocessed = False
def __repr__(self):
return f"SRN: {self.srn}"
@property
def content(self) -> dict:
return self.entity_info.entity
@content.setter
def content(self, value: dict):
self.entity_info.entity = value
def add_child(self, child_node: "EntityNode"):
self.children.add(child_node)
def add_parent(self, parent_node: "EntityNode"):
self.parents.add(parent_node)
def get_parent_srns(self) -> Set[str]:
"""
Get the set of parent SRNs referenced in this entity's content.
"""
entity_content = json.dumps(self.content, separators=(",", ":"))
parent_srns = set(self.SRN_REGEX.findall(entity_content))
parent_srns.discard(self.srn)
return parent_srns
def replace_parents_surrogate_srns(self):
"""
Replace parents' surrogate keys with their system-generated SRNs in the entity content.
"""
if not self.parents:
return
content = json.dumps(self.content)
for parent in self.parents:
if parent.system_srn:
content = content.replace(parent.srn, parent.system_srn)
self.content = json.loads(content)
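# Sketch of the effect (ids are hypothetical): if a parent node has
# srn="surrogate-key:file-1" and system_srn="opendes:dataset--File.Generic:1234",
# every occurrence of that surrogate key in this entity's content is replaced by the
# system-generated id before the entity itself is sent to Storage.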
class ManifestAnalyzer(HeadersMixin):
"""
This class builds a queue for ingesting a set of entities. Each entity can depend on
another one, so the order of ingestion must be prioritized. The idea is to create a dependency
graph and traverse it to get the right order of ingestion.
The flow of prioritizing entities could be described as:
1. Fill graph's nodes with entities (self._fill_srn_node_table())
2. Create links between nodes (self._fill_nodes_edges())
3. Mark unprocessed nodes if they are orphaned or dependent on orphaned nodes (self._find_unprocessed_nodes())
4. Return prioritized queue for ingesting (self.entity_queue())
"""
def __init__(
self,
entities: Iterable[ManifestEntity],
storage_service_url: str,
token_refresher: TokenRefresher,
context: Context
):
super().__init__(context)
self.entities = entities
self.storage_service_url = storage_service_url
self.token_refresher = token_refresher
self.srn_node_table = dict()
self.processed_entities = []
# used as a root for all orphan entities
empty_entity_info = ManifestEntity("", {})
self.unprocessed_entities_parent = EntityNode(srn=str(uuid4()),
entity_info=empty_entity_info)
self.unprocessed_entities = set()
self._fill_srn_node_table()
self._fill_nodes_parents()
self._find_unprocessed_nodes()
def _create_entity_node(self, entity: ManifestEntity):
srn = entity.entity.get("id", f"surrogate-key:{str(uuid4())}")
self.srn_node_table[srn] = EntityNode(srn, entity)
def _create_work_product_entities_nodes(self, work_product: dict):
for part_name, work_product_part in work_product.items():
if part_name == "WorkProduct":
self._create_entity_node(work_product_part)
else:
for p in work_product_part:
self._create_entity_node(p)
def _fill_srn_node_table(self):
for entity in self.entities:
self._create_entity_node(entity)
def _fill_nodes_parents(self):
"""
Find parents in every entity.
"""
for entity_node in self.srn_node_table.values():
self._set_entity_parents(entity_node)
@tenacity.retry(
wait=tenacity.wait_fixed(TIMEOUT),
stop=tenacity.stop_after_attempt(RETRIES),
reraise=True
)
@refresh_token()
def _get_storage_record_request(self, headers: dict, srn: str) -> requests.Response:
logger.debug(f"Searching for {srn}")
return requests.get(f"{self.storage_service_url}/{srn}", headers=headers)
def _is_in_storage(self, parent_srn: str) -> bool:
try:
self._get_storage_record_request(self.request_headers, parent_srn)
return True
except requests.HTTPError:
return False
def _set_entity_parents(self, entity: EntityNode):
"""
Find all parent references in the entity's content.
If a parent is absent from the manifest but already present in Storage, nothing needs to be done.
If a parent is absent from both the manifest and Storage, mark this entity as unprocessed.
"""
parent_srns = entity.get_parent_srns()
for parent_srn in parent_srns:
if self.srn_node_table.get(parent_srn):
parent_node = self.srn_node_table[parent_srn]
parent_node.add_child(entity)
entity.add_parent(parent_node)
elif self._is_in_storage(parent_srn):
continue
else:  # the referenced SRN is present neither in the manifest nor in Storage
self.unprocessed_entities_parent.add_child(entity)
logger.info(f"'{entity}' is orphaned. Missing parent '{parent_srn}'")
def _find_unprocessed_nodes(self):
"""
Traverse entities dependent on orphaned or invalid ones.
Add them to the set of unprocessed nodes to exclude them from the ingestion queue.
"""
queue = deque()
queue.append(self.unprocessed_entities_parent)
while queue:
node = queue.popleft()
self.unprocessed_entities.add(node)
logger.debug(f"Node {node} added to unprocessed.")
for child in node.children:
if not child.unprocessed:
child.unprocessed = True
queue.append(child)
self.unprocessed_entities.discard(self.unprocessed_entities_parent)
def entity_queue(self) -> Iterator[EntityNode]:
"""
Create a queue where a child entity goes after all of its parents.
If an entity is marked as unprocessed, skip it.
"""
entity_graph = {entity: entity.parents for entity in self.srn_node_table.values()}
logger.debug(f"Entity graph {entity_graph}.")
entity_queue = toposort.toposort_flatten(entity_graph, sort=False)
for entity in entity_queue:
if entity not in self.unprocessed_entities:
self.processed_entities.append(entity)
yield entity
for entity in self.unprocessed_entities:
entity.replace_parents_surrogate_srns()
logger.debug(f"Visited entities {self.processed_entities}")
logger.debug(f"Unprocessed entities {self.unprocessed_entities}")
def add_unprocessed_entity(self, entity: EntityNode):
"""
Use this if there are problems with ingesting an entity.
Mark it and its dependents as unprocessed.
"""
self.unprocessed_entities_parent.add_child(entity)
self._find_unprocessed_nodes()
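A minimal usage sketch for ManifestAnalyzer, mirroring how the R3 operator's _process_records consumes it (store_record is a hypothetical helper standing in for ManifestProcessor.process_manifest_records):
analyzer = ManifestAnalyzer(valid_entities, storage_url, token_refresher, context)
for entity in analyzer.entity_queue():
    try:
        # parents are already stored, so surrogate keys can be swapped for real ids
        entity.replace_parents_surrogate_srns()
        entity.system_srn = store_record(entity.entity_info)  # hypothetical helper
    except Exception:
        # exclude the failed entity and everything depending on it from the queue
        analyzer.add_unprocessed_entity(entity)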
......@@ -23,13 +23,14 @@ from typing import List, Tuple
import requests
import tenacity
from libs.constants import RETRIES, WAIT
from libs.context import Context
from libs.constants import RETRIES, WAIT
from libs.exceptions import EmptyManifestError, NotOSDUSchemaFormatError
from libs.handle_file import FileHandler
from libs.mixins import HeadersMixin
from libs.refresh_token import TokenRefresher, refresh_token
from libs.source_file_check import SourceFileChecker
from libs.handle_file import FileHandler
from libs.traverse_manifest import ManifestEntity
RETRY_SETTINGS = {
"stop": tenacity.stop_after_attempt(RETRIES),
......@@ -55,7 +56,6 @@ class ManifestProcessor(HeadersMixin):
def __init__(
self,
storage_url: str,
manifest_records: List[dict],
file_handler: FileHandler,
source_file_checker: SourceFileChecker,
token_refresher: TokenRefresher,
......@@ -63,24 +63,21 @@ class ManifestProcessor(HeadersMixin):
):
"""Manifest processor.
:param storage_url: The OSDU Storage base url
:type storage_url: str
:param dagrun_conf: The conf obtained from dagrun
:type dagrun_conf: dict
:param file_handler: An instance of a file handler
:type file_handler: FileHandler
:param source_file_checker: An instance of file checker
:type source_file_checker: SourceFileChecker
:param token_refresher: An instance of token refresher
:type token_refresher: TokenRefresher
:param storage_url: The OSDU Storage base url
:type storage_url: str
:param context: The tenant context
:type context: Context
:param token_refresher: An instance of token refresher
:type token_refresher: TokenRefresher
"""
super().__init__(context)
self.file_handler = file_handler
self.source_file_checker = source_file_checker
self.storage_url = storage_url
self.manifest_records = manifest_records
self.context = context
self.token_refresher = token_refresher
......@@ -115,6 +112,7 @@ class ManifestProcessor(HeadersMixin):
:rtype: dict
"""
record = copy.deepcopy(self.RECORD_TEMPLATE)
manifest = self._delete_surrogate_key(manifest)
if manifest.get("id"):
record["id"] = manifest["id"]
record["kind"] = manifest.pop("kind")
......@@ -123,6 +121,11 @@ class ManifestProcessor(HeadersMixin):
record["data"] = manifest.pop("data")
return record
def _delete_surrogate_key(self, entity: dict) -> dict:
if "surrogate-key:" in entity.get("id", ""):
del entity["id"]
return entity
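# e.g. an entity with "id": "surrogate-key:wpc-1" loses its id so the Storage service
# assigns a real one, while an entity that already carries a real id keeps it unchanged.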
def _populate_file_storage_record(self, manifest: dict) -> dict:
"""Create a record from file manifest to store it via File service.
......@@ -132,6 +135,7 @@ class ManifestProcessor(HeadersMixin):
:rtype: dict
"""
record = copy.deepcopy(self.RECORD_TEMPLATE)
manifest = self._delete_surrogate_key(manifest)
if manifest.get("id"):
record["id"] = manifest["id"]
record["kind"] = manifest.pop("kind")
......@@ -197,7 +201,8 @@ class ManifestProcessor(HeadersMixin):
if not file_record["data"]["DatasetProperties"]["FileSourceInfo"]["FileSource"]:
file_record = self.upload_source_file(file_record)
else:
file_source = file_record["data"]["DatasetProperties"]["FileSourceInfo"]["FileSource"]
file_source = file_record["data"]["DatasetProperties"]["FileSourceInfo"][
"FileSource"]
file_location = self.file_handler.get_file_staging_location(file_source)
self.source_file_checker.does_file_exist(file_location)
......@@ -205,21 +210,23 @@ class ManifestProcessor(HeadersMixin):
records.append(record)
return records
def process_manifest(self) -> List[str]:
def process_manifest_records(self, manifest_records: List[ManifestEntity]) -> List[str]:
"""Process manifests and save them into Storage service.
:param manifest_records: List of ManifestEntity objects to be ingested.
:raises EmptyManifestError: When manifest is empty
:return: List of ids of saved records
:rtype: List[str]
"""
record_ids = []
populated_manifest_records = []
if not self.manifest_records:
if not manifest_records:
raise EmptyManifestError
for manifest_record in self.manifest_records:
populated_manifest_records.append(self.populate_manifest_storage_record(manifest_record.get("entity")))
for manifest_record in manifest_records:
populated_manifest_records.append(
self.populate_manifest_storage_record(manifest_record.entity))
save_manifests_response = self.save_record_to_storage(
self.request_headers, request_data=populated_manifest_records)
self.request_headers, request_data=populated_manifest_records)
record_ids.extend(save_manifests_response.json()["recordIds"])
return record_ids
......@@ -17,12 +17,23 @@ import copy
import logging
from typing import List
import dataclasses
from libs.exceptions import EmptyManifestError
logger = logging.getLogger()
@dataclasses.dataclass()
class ManifestEntity:
schema: str
entity: dict
def __eq__(self, other: "ManifestEntity"):
return self.entity == other.entity\
and self.schema == other.schema
class ManifestTraversal(object):
"""Class to traverse manifest and extract all manifest records"""
......@@ -40,11 +51,7 @@ class ManifestTraversal(object):
:return:
"""
extracted_schema = schema.split("/")[-1]
logger.debug(f"Extracted schema kind: {extracted_schema}")
return {
"schema": extracted_schema,
"entity": entity
}
return ManifestEntity(entity=entity, schema=extracted_schema)
def _traverse_list(self, manifest_entities: List[dict], property_name: str, manifest_schema_part: dict):
"""
......@@ -57,7 +64,7 @@ class ManifestTraversal(object):
manifest_schema_part[property_name]["items"]["$ref"]))
return entities
def traverse_manifest(self) -> List[dict]:
def traverse_manifest(self) -> List[ManifestEntity]:
"""
Traverse manifest structure and return the list of manifest records.
......
......@@ -16,7 +16,6 @@
"""Provides SchemaValidator."""
import copy
import json
import logging
from typing import Union, Any, List
......@@ -26,6 +25,7 @@ import tenacity
from jsonschema import exceptions
from libs.context import Context
from libs.exceptions import EmptyManifestError, NotOSDUSchemaFormatError
from libs.traverse_manifest import ManifestEntity
from libs.mixins import HeadersMixin
from libs.refresh_token import TokenRefresher, refresh_token
......@@ -74,20 +74,16 @@ class SchemaValidator(HeadersMixin):
"""Class to validate schema of Manifests."""
def __init__(
self, schema_service: str,
self,
schema_service: str,
token_refresher: TokenRefresher,
context: Context
):
"""Init SchemaValidator.
:param schema_service: The base OSDU Schema service url
:type schema_service: str
:param dagrun_conf: The airflow dagrun.conf
:type dagrun_conf: dict
:param token_refresher: An instance of token refresher
:type token_refresher: TokenRefresher
:param context: The tenant context
:type context: Context
"""
super().__init__(context)
self.schema_service = schema_service
......@@ -212,7 +208,7 @@ class SchemaValidator(HeadersMixin):
self._validate_against_schema(schema, manifest)
return schema
def validate_manifest(self, manifest_records: List[dict]) -> List[dict]:
def validate_manifest(self, manifest_records: List[ManifestEntity]) -> List[ManifestEntity]:
"""
Validate the manifest's entities one by one. Return the list of valid entities.
:param manifest_records: List of manifest's records
......@@ -222,7 +218,7 @@ class SchemaValidator(HeadersMixin):
if not manifest_records:
raise EmptyManifestError
for manifest_record in manifest_records:
manifest = manifest_record.get("entity")
manifest = manifest_record.entity
if isinstance(manifest, dict) and manifest.get("kind"):
validation_result = self._validate_entity(manifest)
if validation_result:
......
......@@ -13,12 +13,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""R2 Process Manifest operator."""
"""
R3 Process Manifest operator.
"""
import logging
from airflow.utils import apply_defaults
from airflow.models import BaseOperator, Variable
from libs.constants import FIRST_STORED_RECORD_INDEX
from libs.context import Context
from libs.manifest_analyzer import ManifestAnalyzer
from libs.source_file_check import SourceFileChecker
from libs.handle_file import FileHandler
from libs.refresh_token import AirflowTokenRefresher
......@@ -42,6 +48,24 @@ class ProcessManifestOperatorR3(BaseOperator):
self.schema_service_url = Variable.get('schema_service_url')
self.storage_url = Variable.get('storage_url')
self.file_service_url = Variable.get('file_service_url')
self.processed_entities = []
def _process_records(self, manifest_analyzer: ManifestAnalyzer,
manifest_processor: ManifestProcessor) -> str:
for entity in manifest_analyzer.entity_queue():
try:
logger.debug(f"Process entity {entity}")
entity.replace_parents_surrogate_srns()
record_id = manifest_processor.process_manifest_records(
[entity.entity_info]
)[FIRST_STORED_RECORD_INDEX]
entity.system_srn = record_id
yield record_id
self.processed_entities.append(f"{entity.srn}:{entity.system_srn}")
except Exception as e:
logger.warning(f"Can't process entity {entity}")
logger.error(e)
manifest_analyzer.add_unprocessed_entity(entity)
def execute(self, context: dict):
"""Execute manifest validation then process it.
......@@ -56,31 +80,43 @@ class ProcessManifestOperatorR3(BaseOperator):
:param context: Airflow context
:type context: dict
"""
record_ids = []
execution_context = context["dag_run"].conf["execution_context"]
payload_context = Context.populate(execution_context)
token_refresher = AirflowTokenRefresher()
file_handler = FileHandler(self.file_service_url, token_refresher, payload_context)
source_file_checker = SourceFileChecker()
manifest_processor = ManifestProcessor(
storage_url=self.storage_url,
file_handler=file_handler,
token_refresher=token_refresher,
context=payload_context,
source_file_checker=source_file_checker,
)
validator = SchemaValidator(
self.schema_service_url,
token_refresher,
payload_context
)
manifest_schema = validator.validate_common_schema(execution_context)
traversal = ManifestTraversal(execution_context, manifest_schema)
manifest_entities = traversal.traverse_manifest()
logger.debug(f"entities count: {len(manifest_entities)}")
valid_manifest_entities = validator.validate_manifest(manifest_entities)
logger.debug(f"valid entities count: {len(valid_manifest_entities)}")
manifest_processor = ManifestProcessor(
self.storage_url,
manifest_analyzer = ManifestAnalyzer(
valid_manifest_entities,
file_handler,
source_file_checker,
self.storage_url,
token_refresher,
payload_context,
payload_context
)
record_ids = manifest_processor.process_manifest()
for record_id in self._process_records(manifest_analyzer, manifest_processor):
record_ids.append(record_id)
logger.info(f"Surrogate-key:system-generated-id list {self.processed_entities}")
logger.info(f"Processed ids {record_ids}")
context["ti"].xcom_push(key="record_ids", value=record_ids)
[
{
"id": "surrogate-key:file-1",
"kind": "osdu:wks:dataset--File.Generic:1.0.0",
"acl": {
"owners": [],
"viewers": []
},
"legal": {
"legaltags": [],
"otherRelevantDataCountries": []
},
"data": {
"ResourceSecurityClassification": "osdu:reference-data--ResourceSecurityClassification:RESTRICTED:",
"SchemaFormatTypeID": "osdu:reference-data--SchemaFormatType:TabSeparatedColumnarText:",
"DatasetProperties": {
"FileSourceInfo": {
"FileSource": "",
"PreloadFilePath": "s3://osdu-seismic-test-data/r1/data/provided/markers/7587.csv"
}
}
}
},
{
"id": "surrogate-key:wpc-1",
"kind": "osdu:wks:work-product-component--WellboreMarker:1.0.0",
"acl": {
"owners": [],
"viewers": []
},
"legal": {
"legaltags": [],
"otherRelevantDataCountries": []
},
"data": {
"ResourceSecurityClassification": "osdu:reference-data--ResourceSecurityClassification:RESTRICTED:",
"Name": "7587.csv",
"Description": "Wellbore Marker",
"Datasets": [
"surrogate-key:file-1"
],
"WellboreID": "osdu:master-data--Wellbore:7587:",
"Markers": [
{
"MarkerName": "North Sea Supergroup",
"MarkerMeasuredDepth": 0.0
},
{
"MarkerName": "Ommelanden Formation",
"MarkerMeasuredDepth": 1555.0
},
{
"MarkerName": "Texel Marlstone Member",
"MarkerMeasuredDepth": 2512.5
},
{
"MarkerName": "Upper Holland Marl Member",
"MarkerMeasuredDepth": 2606.0
},
{
"MarkerName": "Middle Holland Claystone Member",
"MarkerMeasuredDepth": 2723.0
},
{
"MarkerName": "Vlieland Claystone Formation",
"MarkerMeasuredDepth": 2758.0
},
{
"MarkerName": "Lower Volpriehausen Sandstone Member",
"MarkerMeasuredDepth": 2977.5
},
{
"MarkerName": "Rogenstein Member",
"MarkerMeasuredDepth": 3018.0
},
{
"MarkerName": "FAULT",
"MarkerMeasuredDepth": 3043.0
},
{
"MarkerName": "Upper Zechstein salt",
"MarkerMeasuredDepth": 3043.0
},
{
"MarkerName": "FAULT",
"MarkerMeasuredDepth": 3544.0
},
{
"MarkerName": "Z3 Carbonate Member",
"MarkerMeasuredDepth": 3544.0
},
{
"MarkerName": "Z3 Main Anhydrite Member",
"MarkerMeasuredDepth": 3587.0
},
{
"MarkerName": "FAULT",
"MarkerMeasuredDepth": 3622.0
},
{
"MarkerName": "Z3 Salt Member",
"MarkerMeasuredDepth": 3622.0
},
{
"MarkerName": "Z3 Main Anhydrite Member",
"MarkerMeasuredDepth": 3666.5
},
{
"MarkerName": "Z3 Carbonate Member",
"MarkerMeasuredDepth": 3688.0
},
{
"MarkerName": "Z2 Salt Member",
"MarkerMeasuredDepth": 3709.0
},
{
"MarkerName": "Z2 Basal Anhydrite Member",
"MarkerMeasuredDepth": 3985.0
},
{
"MarkerName": "Z2 Carbonate Member",
"MarkerMeasuredDepth": 3996.0
},
{
"MarkerName": "Z1 (Werra) Formation",
"MarkerMeasuredDepth": 4022.5
},
{
"MarkerName": "Ten Boer Member",
"MarkerMeasuredDepth": 4070.0
},
{
"MarkerName": "Upper Slochteren Member",
"MarkerMeasuredDepth": 4128.5
},
{
"MarkerName": "Ameland Member",
"MarkerMeasuredDepth": 4231.0
},
{
"MarkerName": "Lower Slochteren Member",
"MarkerMeasuredDepth": 4283.5
}
]
}
},
{
"kind": "osdu:wks:work-product--WorkProduct:1.0.0",
"acl": {
"owners": [],
"viewers": []
},
"legal": {
"legaltags": [],
"otherRelevantDataCountries": []
},
"data": {
"ResourceSecurityClassification": "osdu:reference-data--ResourceSecurityClassification:RESTRICTED:",
"Name": "7587.csv",
"Description": "Wellbore Marker",
"Components": [
"surrogate-key:wpc-1"
]
}
}
]
......@@ -39,3 +39,6 @@ TRAVERSAL_MANIFEST_EMPTY_PATH = f"{DATA_PATH_PREFIX}/invalid/TraversalEmptyManif
SEARCH_VALID_RESPONSE_PATH = f"{DATA_PATH_PREFIX}/other/SearchResponseValid.json"
SEARCH_INVALID_RESPONSE_PATH = f"{DATA_PATH_PREFIX}/other/SearchResponseInvalid.json"
BATCH_MANIFEST_WELLBORE = f"{DATA_PATH_PREFIX}/batchManifest/Wellbore.0.3.0.json"
# Copyright 2020 Google LLC
# Copyright 2020 EPAM Systems
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
import logging
import sys
from functools import partial
import pytest
from file_paths import BATCH_MANIFEST_WELLBORE
sys.path.append(f"{os.getenv('AIRFLOW_SRC_DIR')}/plugins")
sys.path.append(f"{os.getenv('AIRFLOW_SRC_DIR')}/dags")
from libs.manifest_analyzer import ManifestAnalyzer, EntityNode
from libs.traverse_manifest import ManifestEntity
from libs.refresh_token import AirflowTokenRefresher
from libs.context import Context
logger = logging.getLogger()
TEST_FAKE_DATA = [
{
"id": "1",
"parents": [],
},
{
"id": "2",
"parents": ["1"],
},
{
"id": "3",
"parents": ["1"],
},
{
"id": "4",
"parents": ["2"],
},
{
"id": "5",
"parents": ["1", "3"],
},
{
"id": "7",
"parents": ["1", "6"],
},
{
"id": "9",
"parents": ["7"]
}
]
TEST_FAKE_DATA = [ManifestEntity(entity=i, schema="") for i in TEST_FAKE_DATA]
class TestManifestAnalyzer(object):
@pytest.fixture()
def manifest_analyzer(self):
with open(BATCH_MANIFEST_WELLBORE) as f:
data = json.load(f)
context = Context(data_partition_id="test", app_key="test")
token_refresher = AirflowTokenRefresher()
return ManifestAnalyzer(data, "http://test.com", token_refresher, context)
@pytest.fixture()
def fake_data_manifest_analyzer(self, monkeypatch, data, is_in_storage: bool = False):
context = Context(data_partition_id="test", app_key="test")
token_refresher = AirflowTokenRefresher()
monkeypatch.setattr(EntityNode, "get_parent_srns", self.mock_get_parent_srns)
monkeypatch.setattr(ManifestAnalyzer, "_is_in_storage", lambda self, srn: is_in_storage)
manifest_analyzer = ManifestAnalyzer(data, "http://test", token_refresher, context)
return manifest_analyzer
def process_entity(self, entity: EntityNode) -> str:
return f"system_srn: {entity.srn}"
def index_in_queue_by_srn(
self,
manifest_analyzer: ManifestAnalyzer,
queue: list,
srn: str
):
entity_node = manifest_analyzer.srn_node_table[srn]
return queue.index(entity_node)
@staticmethod
def mock_get_parent_srns(obj: EntityNode):
parent_srns = set(obj.content.get("parents", []))
return parent_srns
@pytest.mark.parametrize(
"data",
[
pytest.param(TEST_FAKE_DATA, id="Fake data")
]
)
def test_queue_order(
self,
monkeypatch,
fake_data_manifest_analyzer: ManifestAnalyzer,
data: dict
):
"""
Here we use an array of simple objects where it is immediately clear who depends on whom.
Check that the queue returns parents, and only then their children.
Check that orphaned entities and their children (SRN 7 and SRN 9) are not in the queue.
"""
queue = list(fake_data_manifest_analyzer.entity_queue())
index_in_queue = partial(self.index_in_queue_by_srn, fake_data_manifest_analyzer, queue)
# check if child goes after all its parents in queue.
assert index_in_queue("5") > index_in_queue("1") \
and index_in_queue("5") > index_in_queue("3"), \
"SRN 5 must follow parents: SRN 1 and 3"
# check if orphans and their dependants are not in ingestion queue.
for unprocessed_srn in ("7", "9"):
unprocessed_entity = fake_data_manifest_analyzer.srn_node_table[unprocessed_srn]
assert unprocessed_entity not in queue, \
f"{unprocessed_entity} expected not to be in queue: {queue}"
@pytest.mark.parametrize(
"data",
[
pytest.param(TEST_FAKE_DATA, id="Fake data")
]
)
def test_add_new_unprocessed(
self,
monkeypatch,
fake_data_manifest_analyzer: ManifestAnalyzer,
data: dict
):
"""
Here we use an array of simple objects where it is immediately clear who depends on whom.
Imagine we can't process an entity (e.g. the Storage service can't save it).
Then we must add this entity to the unprocessed ones and traverse all of its children,
marking them as unprocessed.
They must disappear from the ingestion queue.
"""
queue = fake_data_manifest_analyzer.entity_queue()
unprocessed_entity = fake_data_manifest_analyzer.srn_node_table["3"]
expected_unprocessed_entities = {"7", "9", "3", "5"}
fake_data_manifest_analyzer.add_unprocessed_entity(unprocessed_entity)
for entity in queue:
assert entity.srn not in expected_unprocessed_entities, \
f"{entity} must be excluded from queue."
@pytest.mark.parametrize(
"data, is_in_storage",
[
pytest.param(TEST_FAKE_DATA, True, id="Fake data")
]
)
def test_all_missed_parents_are_in_storage(
self,
monkeypatch,
data: dict,
is_in_storage: bool
):
context = Context(data_partition_id="test", app_key="test")
token_refresher = AirflowTokenRefresher()
monkeypatch.setattr(EntityNode, "get_parent_srns", self.mock_get_parent_srns)
monkeypatch.setattr(ManifestAnalyzer, "_is_in_storage", lambda self, srn: True)
manifest_analyzer = ManifestAnalyzer(data, "http://test", token_refresher, context)
queue = list(manifest_analyzer.entity_queue())
assert not manifest_analyzer.unprocessed_entities, \
"If absent parents are in storage, there are no orphaned child then."
def test_real_data(self):
with open(BATCH_MANIFEST_WELLBORE) as f:
data = json.load(f)
data = [ManifestEntity(entity=e, schema="") for e in data]
context = Context(data_partition_id="test", app_key="test")
token_refresher = AirflowTokenRefresher()
manifest_analyzer = ManifestAnalyzer(data, "http://test.com", token_refresher, context)
for entity in manifest_analyzer.entity_queue():
entity.replace_parents_surrogate_srns()
entity.system_srn = self.process_entity(entity)
logger.info(f"Processed entity {json.dumps(entity.content, indent=2)}")
logger.info("\n")
logger.info(f"Processed entities: {manifest_analyzer.processed_entities}")
logger.info(f"Unprocessed entities {manifest_analyzer.unprocessed_entities}")
......@@ -26,8 +26,8 @@ from libs.context import Context
from libs.handle_file import FileHandler
from libs.source_file_check import SourceFileChecker
from libs.refresh_token import AirflowTokenRefresher
from libs.traverse_manifest import ManifestEntity
from libs.exceptions import EmptyManifestError
from deepdiff import DeepDiff
import pytest
import requests
......@@ -77,8 +77,14 @@ class TestManifestProcessor:
monkeypatch.setattr(requests, "put", mockresponse)
@pytest.fixture()
def manifest_records(self, traversal_manifest_file: str) -> list:
with open(traversal_manifest_file) as f:
manifest_file = json.load(f)
return manifest_file
@pytest.fixture(autouse=True)
def manifest_processor(self, monkeypatch, conf_path: str, traversal_manifest_file: str):
def manifest_processor(self, monkeypatch, traversal_manifest_file, conf_path: str):
with open(conf_path) as f:
conf = json.load(f)
with open(traversal_manifest_file) as f:
......@@ -90,7 +96,6 @@ class TestManifestProcessor:
source_file_checker = SourceFileChecker()
manifest_processor = process_manifest_r3.ManifestProcessor(
storage_url="",
manifest_records=manifest_records,
token_refresher=token_refresher,
context=context,
file_handler=file_handler,
......@@ -132,6 +137,7 @@ class TestManifestProcessor:
self,
monkeypatch,
manifest_processor: process_manifest_r3.ManifestProcessor,
manifest_records,
mock_records_list: list,
traversal_manifest_file: str,
conf_path: str,
......@@ -150,6 +156,7 @@ class TestManifestProcessor:
def test_save_record_invalid_storage_response_value(
self,
monkeypatch,
manifest_records,
manifest_processor: process_manifest_r3.ManifestProcessor,
traversal_manifest_file: str,
conf_path: str
......@@ -169,6 +176,7 @@ class TestManifestProcessor:
def test_save_record_storage_response_http_error(
self,
monkeypatch,
manifest_records,
manifest_processor: process_manifest_r3.ManifestProcessor,
traversal_manifest_file: str,
conf_path: str
......@@ -191,12 +199,14 @@ class TestManifestProcessor:
def test_process_manifest_valid(
self,
monkeypatch,
manifest_records,
manifest_processor: process_manifest_r3.ManifestProcessor,
traversal_manifest_file: str,
conf_path: str
):
self.monkeypatch_storage_response(monkeypatch)
manifest_processor.process_manifest()
manifest_file = [ManifestEntity(entity=e["entity"], schema=e["schema"]) for e in manifest_records]
manifest_processor.process_manifest_records(manifest_file)
@pytest.mark.parametrize(
"conf_path,traversal_manifest_file",
......@@ -207,13 +217,14 @@ class TestManifestProcessor:
def test_process_empty_manifest(
self,
monkeypatch,
manifest_records,
manifest_processor: process_manifest_r3.ManifestProcessor,
traversal_manifest_file: str,
conf_path: str
):
self.monkeypatch_storage_response(monkeypatch)
with pytest.raises(EmptyManifestError):
manifest_processor.process_manifest()
manifest_processor.process_manifest_records(manifest_records)
@pytest.mark.parametrize(
"conf_path,expected_kind_name,traversal_manifest_file",
......@@ -225,10 +236,11 @@ class TestManifestProcessor:
self,
monkeypatch,
manifest_processor: process_manifest_r3.ManifestProcessor,
manifest_records: list,
conf_path: str,
traversal_manifest_file: str,
expected_kind_name: str
):
for manifest_part in manifest_processor.manifest_records:
for manifest_part in manifest_records:
kind = manifest_part["entity"]["kind"]
assert expected_kind_name == manifest_processor._get_kind_name(kind)
......@@ -25,7 +25,7 @@ from file_paths import MANIFEST_WELLBORE_VALID_PATH, TRAVERSAL_WELLBORE_VALID_PA
MANIFEST_SEISMIC_TRACE_DATA_VALID_PATH, TRAVERSAL_SEISMIC_TRACE_DATA_VALID_PATH, MANIFEST_EMPTY_PATH, \
TRAVERSAL_MANIFEST_EMPTY_PATH, MANIFEST_GENERIC_SCHEMA_PATH
from libs.exceptions import EmptyManifestError
from libs.traverse_manifest import ManifestTraversal
from libs.traverse_manifest import ManifestTraversal, ManifestEntity
class TestManifestTraversal:
......@@ -60,6 +60,7 @@ class TestManifestTraversal:
manifest_schema_file: str, traversal_manifest_file: str):
with open(traversal_manifest_file) as f:
traversal_manifest = json.load(f)
traversal_manifest = [ManifestEntity(**e) for e in traversal_manifest]
manifest_records = manifest_traversal.traverse_manifest()
assert manifest_records == traversal_manifest
......
......@@ -73,6 +73,7 @@ class TestOperators(object):
with open(MANIFEST_GENERIC_SCHEMA_PATH) as f:
manifest_schema = json.load(f)
return manifest_schema
monkeypatch.setattr(SchemaValidator, "get_schema", _get_common_schema)
monkeypatch.setattr(SchemaValidator, "validate_manifest", lambda obj, entities: entities)
monkeypatch.setattr(ManifestProcessor, "save_record_to_storage",
......
......@@ -44,6 +44,7 @@ from file_paths import (
)
from mock_responses import MockSchemaResponse
from libs.context import Context
from libs.traverse_manifest import ManifestEntity
from libs.refresh_token import AirflowTokenRefresher
from libs.exceptions import EmptyManifestError, NotOSDUSchemaFormatError
import pytest
......@@ -111,6 +112,8 @@ class TestSchemaValidator:
monkeypatch.setattr(schema_validator, "get_schema", self.mock_get_schema)
with open(traversal_manifest_file_path) as f:
manifest_file = json.load(f)
manifest_file = [ManifestEntity(entity=e["entity"], schema=e["schema"]) for e in manifest_file]
validated_records = schema_validator.validate_manifest(manifest_file)
assert len(manifest_file) == len(validated_records)
......@@ -149,6 +152,7 @@ class TestSchemaValidator:
schema_file: str):
with open(traversal_manifest_file) as f:
manifest_file = json.load(f)
manifest_file = [ManifestEntity(entity=e["entity"], schema=e["schema"]) for e in manifest_file]
with pytest.raises(NotOSDUSchemaFormatError):
schema_validator.validate_manifest(manifest_file)
......
......@@ -17,6 +17,7 @@ pip install --upgrade google-api-python-client
pip install dataclasses
pip install jsonschema
pip install google
pip install toposort
pip install google-cloud-storage
pip install deepdiff
pip install azure-identity
......