Commit 34df8623 authored by Valentin Gauthier's avatar Valentin Gauthier
Manifest by reference, first code version

parent 049757ac
Pipeline #98611 passed with stages
in 18 minutes and 3 seconds
# Copyright 2021 Google LLC
# Copyright 2021 EPAM Systems
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""R3 Validate reference Manifest operator."""
import logging
from airflow.models import BaseOperator, Variable
from osdu_ingestion.libs.context import Context
from osdu_ingestion.libs.refresh_token import AirflowTokenRefresher
from osdu_ingestion.libs.validation.validate_file_source import FileSourceValidator
from osdu_ingestion.libs.validation.validate_referential_integrity import ManifestIntegrity
from osdu_airflow.backward_compatibility.airflow_utils import apply_defaults
from osdu_airflow.operators.mixins.ReceivingContextMixin import ReceivingContextMixin
logger = logging.getLogger()
class EnsureManifestIntegrityOperatorByReference(BaseOperator, ReceivingContextMixin):
"""Operator to validate ref inside manifest R3 and remove invalid entities."""
ui_color = '#dad5ff'
ui_fgcolor = '#000000'
@apply_defaults
def __init__(self,
                 previous_task_id: str = None,
*args, **kwargs):
"""Init base operator and obtain base urls from Airflow Variables."""
super().__init__(*args, **kwargs)
self.whitelist_ref_patterns = Variable.get('core__config__reference_patterns_whitelist', default_var=None)
self.previous_task_id = previous_task_id
self._show_skipped_ids = Variable.get(
'core__config__show_skipped_ids', default_var=False
)
def execute(self, context: dict):
"""Execute manifest validation then process it.
:param context: Airflow context
:type context: dict
"""
        execution_context = context["dag_run"].conf["execution_context"]
        payload_context = Context.populate(execution_context)
token_refresher = AirflowTokenRefresher()
file_source_validator = FileSourceValidator()
manifest_integrity = ManifestIntegrity(
token_refresher,
file_source_validator,
payload_context,
self.whitelist_ref_patterns,
)
# manifest_data = self._get_manifest_data(context, execution_context)
manifest_data = self._get_manifest_data_by_reference(context, execution_context)
previously_skipped_entities = self._get_previously_skipped_entities(context)
logger.debug(f"Manifest data: {manifest_data}")
manifest, skipped_ids = manifest_integrity.ensure_integrity(
manifest_data,
previously_skipped_entities
)
logger.debug(f"Valid manifest data: {manifest_data}")
if self._show_skipped_ids:
context["ti"].xcom_push(key="skipped_ids", value=skipped_ids)
# return {"manifest": manifest}
return self._put_manifest_data_by_reference(context, execution_context, manifest)
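
# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the commit: the two Airflow Variables read
# in this operator's __init__ can be seeded for local testing as shown below.
# The pattern value is a hypothetical placeholder, not one used by this
# repository.
if __name__ == "__main__":
    Variable.set("core__config__reference_patterns_whitelist", r"^osdu:.*")
    Variable.set("core__config__show_skipped_ids", "True")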
@@ -18,6 +18,25 @@ import logging
from osdu_ingestion.libs.types import ManifestType
# Imports added for manifest-by-reference support:
from osdu_api.model.http_method import HttpMethod
from osdu_api.model.dataset.create_dataset_registries_request import CreateDatasetRegistriesRequest
from osdu_api.clients.entitlements.entitlements_client import EntitlementsClient
from osdu_api.clients.dataset.dataset_dms_client import DatasetDmsClient
from osdu_api.clients.dataset.dataset_registry_client import DatasetRegistryClient
from osdu_api.model.storage.record import Record
from osdu_api.model.storage.record_ancestry import RecordAncestry
from osdu_api.model.storage.acl import Acl
from osdu_api.model.storage.legal import Legal
from osdu_api.configuration.config_manager import DefaultConfigManager
import re
import json
logger = logging.getLogger()
class ReceivingContextMixin:
"""Mixin for receiving manifest file from XCOMs in case if current operator not the first in the row"""
@@ -49,3 +68,164 @@ class ReceivingContextMixin:
if task_skipped_ids:
previously_skipped_ids.extend(task_skipped_ids)
return previously_skipped_ids
    def _get_manifest_data_by_reference(self, context: dict, execution_context: dict, use_history: bool = False) -> ManifestType:
        """
        [Geosiris Development]
        Get the manifest from the dataset service.
        If use_history is True, the record id is taken from the "manifest_ref_ids" history instead of the last task's return value.
        """
config_manager = DefaultConfigManager()
data_partition_id = config_manager.get('environment', 'data_partition_id')
# logger.debug(f"#ReceivingContextMixin _get_manifest_data_by_reference starts ")
if use_history:
# record_id_list = execution_context["manifest_ref_ids"]
record_id_list = context["ti"].xcom_pull(task_ids=self.previous_task_id, key="manifest_ref_ids")
record_id = record_id_list[-1] # the last one is the most recent one
elif self.previous_task_id:
record_id = context["ti"].xcom_pull(task_ids=self.previous_task_id, key="return_value")
        else:
            logger.debug("#ReceivingContextMixin _get_manifest_data_by_reference: no record id found (consider using the record id history)")
            raise ValueError("No previous task id to pull the manifest record id from, and use_history is False")
        client_dms = DatasetDmsClient(data_partition_id=data_partition_id)
        retrieval = client_dms.get_retrieval_instructions(record_id=record_id).json()
# logger.debug(f"#ReceivingContextMixin retrieval : {retrieval}")
client_reg = DatasetRegistryClient()
manifest_data = client_reg.get_dataset_registry(record_id=record_id)
# logger.debug(f"#ReceivingContextMixin manifest_data : {manifest_data.json()}")
return manifest_data.json()['datasetRegistries'][0]
    def _put_manifest_data_by_reference(self, context: dict, execution_context: dict, manifest_data, use_history: bool = False) -> int:
        """
        [Geosiris Development]
        Put the manifest into the dataset service and get back a reference to its content.
        If use_history is True, the manifest record id is still returned, but it is also saved into XCom under the key "manifest_ref_ids"
        (a small mock of this history mechanism follows this class).
        The history may also contain the record ids from previous tasks if they used the history as well.
        """
config_manager = DefaultConfigManager()
data_partition_id = config_manager.get('environment', 'data_partition_id')
dataset_registry_url = config_manager.get('environment', 'dataset_registry_url')
match_domain = re.search(r'https?://([\w\.-]+).*', dataset_registry_url)
dataset_registry_url_domain = match_domain.group(1)
# logger.debug(f"##ReceivingContextMixin dataset_registry_url {dataset_registry_url} : \ndomain is {dataset_registry_url_domain}")
client_dms = DatasetDmsClient(data_partition_id=data_partition_id)
        storage_instruction = client_dms.get_storage_instructions(kind_sub_type="dataset--File.Generic")  # TODO: change to a manifest-specific kind
# logger.debug(f"##ReceivingContextMixin storage_instruction : {storage_instruction.json()}")
signedUrl = ""
try:
signedUrl = storage_instruction.json()['storageLocation']['signedUrl']
except Exception as e:
logger.debug("No 'signed' parameter url found for storage location")
unsignedUrl = ""
try:
unsignedUrl = storage_instruction.json()['storageLocation']['unsignedUrl']
except Exception as e:
logger.debug("No 'unsigned' parameter url found for storage location")
fileSource = ""
try:
fileSource = storage_instruction.json()['storageLocation']['fileSource']
except Exception as e:
logger.debug("No 'filesource' parameter found for storage location")
signedUploadFileName = ""
try:
signedUploadFileName = storage_instruction.json()['storageLocation']['signedUploadFileName']
except Exception as e:
logger.debug("No 'signedUploadFileName' parameter found for storage location")
# Uploading data
# logger.debug(f"##ReceivingContextMixin signedUrl : {signedUrl}")
        put_result = client_dms.make_request(method=HttpMethod.PUT, url=signedUrl, data=json.dumps(manifest_data))
# logger.debug(f"##ReceivingContextMixin put_result : {put_result}")
########## ACL
ent_client = EntitlementsClient()
ent_response = ent_client.get_groups_for_user()
# logger.debug(f"ent_client : {ent_response.json()}")
# logger.debug(f"ent_client : {ent_response.json()['groups']}")
acl_domain = data_partition_id + "." + dataset_registry_url_domain
data_default_viewers = "data.default.viewers@" + acl_domain
data_default_owners = "data.default.owners@" + acl_domain
viewers_found = False
owners_found = False
        # Should we use the ACL from the manifest, or the default ACL as below?
        # Search for the ACL in the entitlements client response:
        # we look for a group name that contains "data" and "viewers"/"owners".
for ent_grp_elt in ent_response.json()["groups"]:
# logger.debug(f"ent_client : {ent_grp_elt}")
if "data" in ent_grp_elt["name"] and "viewers" in ent_grp_elt["name"]:
data_default_viewers = ent_grp_elt["email"]
viewers_found = True
if "data" in ent_grp_elt["name"] and "owner" in ent_grp_elt["name"]:
data_default_owners = ent_grp_elt["email"]
owners_found = True
if viewers_found and owners_found:
break
        acl_data = Acl(viewers=[data_default_viewers],
                       owners=[data_default_owners])
# logger.debug(f"ACL for the record are set to : {acl_data.__dict__}")
        # TODO: fetch the legal tag with a request to the dedicated service:
        # legal tags may be retrieved from "/api/legal/v1/legaltags"
        recordList = [
            Record(kind="osdu:wks:dataset--File.Generic:1.0.0",
                   acl=acl_data,
                   legal=Legal(legaltags=[data_partition_id + "-demo-legaltag"],
                               other_relevant_data_countries=["US"],
                               status="compliant"),
                   data={
                       "DatasetProperties": {
                           "FileSourceInfo": {
                               "FileSource": fileSource,
                               "PreLoadFilePath": unsignedUrl + signedUploadFileName
                           }
                       },
                       "ResourceSecurityClassification": "osdu:reference-data--ResourceSecurityClassification:RESTRICTED:",
                       "SchemaFormatTypeID": "osdu:reference-data--SchemaFormatType:TabSeparatedColumnarText:"
                   },
                   # id: str = None,
                   version=1614105463059152,  # hard-coded placeholder version
                   ancestry=RecordAncestry(parents=[]))
        ]
client_reg = DatasetRegistryClient()
registered_dataset = client_reg.register_dataset(CreateDatasetRegistriesRequest(dataset_registries=recordList))
# logger.debug(f"##ReceivingContextMixin storage_instruction : {registered_dataset.json()}")
record_id = registered_dataset.json()['datasetRegistries'][0]["id"]
if use_history:
manifest_ref_ids = context["ti"].xcom_pull(task_ids=self.previous_task_id, key="manifest_ref_ids")
if manifest_ref_ids is None:
manifest_ref_ids = []
manifest_ref_ids.append(record_id)
context["ti"].xcom_push(key="manifest_ref_ids", value=manifest_ref_ids)
# logger.debug(f"##ReceivingContextMixin record_id_list : {manifest_ref_ids}")
return record_id
\ No newline at end of file
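
# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the commit: the "manifest_ref_ids" history
# used by _put_manifest_data_by_reference is a plain append-only list in XCom,
# so a chain of tasks that all pass use_history=True accumulates one dataset
# record id per task, and _get_manifest_data_by_reference reads the most
# recent one. The mock below demonstrates the append-or-create pattern without
# a running Airflow instance; the record ids are hypothetical placeholders.
if __name__ == "__main__":
    history = None  # what xcom_pull returns before any task has pushed the key
    for new_record_id in ("dp:dataset--File.Generic:111", "dp:dataset--File.Generic:222"):
        if history is None:
            history = []
        history.append(new_record_id)  # mirrors _put_manifest_data_by_reference
    print(history[-1])  # the last id is the most recent one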
# Copyright 2020 Google LLC
# Copyright 2020 EPAM Systems
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
R3 Process Manifest operator.
"""
import logging
from math import ceil
from typing import List, Tuple
from airflow.models import BaseOperator, Variable
from jsonschema import SchemaError
from osdu_ingestion.libs.constants import DATA_TYPES_WITH_SURROGATE_KEYS, SURROGATE_KEYS_PATHS
from osdu_ingestion.libs.context import Context
from osdu_ingestion.libs.exceptions import (EmptyManifestError, GenericManifestSchemaError,
GetSchemaError, NotOSDUSchemaFormatError,
UploadFileError)
from osdu_ingestion.libs.handle_file import FileHandler
from osdu_ingestion.libs.process_manifest_r3 import ManifestProcessor
from osdu_ingestion.libs.processors.single_manifest_processor import SingleManifestProcessor
from osdu_ingestion.libs.refresh_token import AirflowTokenRefresher
from osdu_ingestion.libs.source_file_check import SourceFileChecker
from osdu_ingestion.libs.types import ManifestType
from osdu_ingestion.libs.validation.validate_file_source import FileSourceValidator
from osdu_ingestion.libs.validation.validate_referential_integrity import ManifestIntegrity
from osdu_ingestion.libs.validation.validate_schema import SchemaValidator
from requests import HTTPError
from osdu_airflow.backward_compatibility.airflow_utils import apply_defaults
from osdu_airflow.operators.mixins.ReceivingContextMixin import ReceivingContextMixin
logger = logging.getLogger()
class ProcessManifestOperatorR3ByReference(BaseOperator, ReceivingContextMixin):
"""Operator to process manifest R3."""
ui_color = '#dad5ff'
ui_fgcolor = '#000000'
    # TODO: Use the corresponding constant from osdu_api later.
SAVE_RECORDS_BATCH_SIZE = 500
@apply_defaults
def __init__(self, previous_task_id: str = None, batch_number=3, *args, **kwargs):
"""Init base operator and obtain base urls from Airflow Variables."""
super().__init__(*args, **kwargs)
self.previous_task_id = previous_task_id
self.batch_number = batch_number
self.storage_url = Variable.get('core__service__storage__url')
self.file_service_host = Variable.get('core__service__file__host')
self.batch_count = int(Variable.get("core__ingestion__batch_count", "3"))
self.batch_save_enabled = Variable.get("core__ingestion__batch_save_enabled", default_var=False, deserialize_json=True)
self.batch_save_size = int(Variable.get("core__ingestion__batch_save_size", default_var=self.SAVE_RECORDS_BATCH_SIZE))
self._show_skipped_ids = Variable.get('core__config__show_skipped_ids', default_var=False)
def _get_manifest_files_range(self, manifests: List[dict]) -> Tuple[int, int]:
"""
        Get the start and end indexes of the manifest-files slice to be processed within this task (a worked slicing example follows this class).
:param manifests: A list of manifests
:return: start index, end index
"""
split_size = ceil(len(manifests) / self.batch_count)
slice_start_index = (self.batch_number - 1) * split_size
slice_end_index = self.batch_number * split_size
return slice_start_index, slice_end_index
def _process_manifest(
self,
single_manifest_processor: SingleManifestProcessor,
manifest: ManifestType
) -> Tuple[List[str], List[dict]]:
"""
:param single_manifest_processor: Object to process a single manifest file.
        Processing includes validation against schemas, storing records, etc.
        :param manifest: A single manifest file or a list of them.
        :return: Tuple of saved record ids and skipped entities.
"""
skipped_entities = []
if isinstance(manifest, dict):
record_ids, skipped_entities = single_manifest_processor.process_manifest(
manifest, False)
elif isinstance(manifest, list):
record_ids = []
slice_start_index, slice_end_index = self._get_manifest_files_range(manifest)
logger.debug(f"Start and indexes {slice_start_index}:{slice_end_index}")
for single_manifest in manifest[slice_start_index:slice_end_index]:
logger.debug(f"processing {single_manifest}")
try:
saved_records_ids, not_saved_records = single_manifest_processor.process_manifest(
single_manifest, True
)
record_ids.extend(saved_records_ids)
skipped_entities.extend(not_saved_records)
except (UploadFileError, HTTPError, GetSchemaError, SchemaError,
GenericManifestSchemaError) as e:
logger.error(f"Can't process {single_manifest}")
logger.error(e)
continue
else:
raise NotOSDUSchemaFormatError(
f"Manifest {manifest} must be either not empty 'list' or 'dict'")
return record_ids, skipped_entities
def execute(self, context: dict):
"""Execute manifest validation then process it.
Get a single manifest file or a list of them.
If it is a list, calculate which range (slice) of manifest files must be processed and then
process this range one by one.
:param context: Airflow context
:type context: dict
"""
execution_context = context["dag_run"].conf["execution_context"]
payload_context = Context.populate(execution_context)
token_refresher = AirflowTokenRefresher()
file_handler = FileHandler(self.file_service_host, token_refresher, payload_context)
file_source_validator = FileSourceValidator()
source_file_checker = SourceFileChecker()
referential_integrity_validator = ManifestIntegrity(
token_refresher,
file_source_validator,
payload_context
)
manifest_processor = ManifestProcessor(
file_handler=file_handler,
token_refresher=token_refresher,
context=payload_context,
source_file_checker=source_file_checker,
)
validator = SchemaValidator(
token_refresher,
payload_context,
data_types_with_surrogate_ids=DATA_TYPES_WITH_SURROGATE_KEYS,
surrogate_key_fields_paths=SURROGATE_KEYS_PATHS
)
single_manifest_processor = SingleManifestProcessor(
storage_url=self.storage_url,
payload_context=payload_context,
referential_integrity_validator=referential_integrity_validator,
manifest_processor=manifest_processor,
schema_validator=validator,
token_refresher=token_refresher,
batch_save_enabled=self.batch_save_enabled,
save_records_batch_size=self.batch_save_size
)
manifest_data = self._get_manifest_data_by_reference(context, execution_context)
logger.debug(f"Manifest data: {manifest_data}")
if not manifest_data:
raise EmptyManifestError(
f"Data {context['dag_run'].conf} doesn't contain 'manifest field'")
record_ids, skipped_ids = self._process_manifest(single_manifest_processor, manifest_data)
logger.info(f"Processed ids {record_ids}")
context["ti"].xcom_push(key="record_ids", value=record_ids)
if self._show_skipped_ids:
context["ti"].xcom_push(key="skipped_ids", value=skipped_ids)
# Copyright 2021 Google LLC
# Copyright 2021 EPAM Systems
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Validate Manifest against R3 schemas operator.
"""
import logging
from airflow.models import BaseOperator, Variable
from osdu_ingestion.libs.constants import DATA_TYPES_WITH_SURROGATE_KEYS, SURROGATE_KEYS_PATHS
from osdu_ingestion.libs.context import Context
from osdu_ingestion.libs.exceptions import EmptyManifestError, GenericManifestSchemaError
from osdu_ingestion.libs.refresh_token import AirflowTokenRefresher
from osdu_ingestion.libs.validation.validate_schema import SchemaValidator
from osdu_airflow.backward_compatibility.airflow_utils import apply_defaults
from osdu_airflow.operators.mixins.ReceivingContextMixin import ReceivingContextMixin
logger = logging.getLogger()
class ValidateManifestSchemaOperatorByReference(BaseOperator, ReceivingContextMixin):
"""Operator to validate manifest against definition schemasR3."""
ui_color = '#dad5ff'
ui_fgcolor = '#000000'
@apply_defaults
def __init__(self, previous_task_id: str = None, *args, **kwargs):
super().__init__(*args, **kwargs)
self.previous_task_id = previous_task_id
self._show_skipped_ids = Variable.get(
'core__config__show_skipped_ids', default_var=False
)
def execute(self, context: dict):
"""Execute manifest validation then process it.
Get a single manifest file or a list of them.
If it is a list, calculate which range (slice) of manifest files must be processed and then
process this range one by one.
:param context: Airflow context
:type context: dict
"""
logger.debug("Starting Validating manifest")
execution_context = context["dag_run"].conf["execution_context"]
payload_context = Context.populate(execution_context)
token_refresher = AirflowTokenRefresher()
logger.debug(f"DATA_TYPES_WITH_SURROGATE_KEYS: {DATA_TYPES_WITH_SURROGATE_KEYS}")
logger.debug(f"SURROGATE_KEYS_PATHS: {SURROGATE_KEYS_PATHS}")
schema_validator = SchemaValidator(
token_refresher,
payload_context,
surrogate_key_fields_paths=SURROGATE_KEYS_PATHS,
data_types_with_surrogate_ids=DATA_TYPES_WITH_SURROGATE_KEYS
)
manifest_data = self._get_manifest_data(context, execution_context)
# manifest_data = self._get_manifest_data_by_reference(context, execution_context)
# logger.debug(f"Manifest data: {manifest_data}")
if not manifest_data:
raise EmptyManifestError(
f"Data {context['dag_run'].conf} doesn't contain 'manifest field'")
_ = schema_validator.validate_common_schema(manifest_data)
try:
valid_manifest_file, skipped_entities = schema_validator.ensure_manifest_validity(
manifest_data
)
except GenericManifestSchemaError as err:
context["ti"].xcom_push(key="skipped_ids", value=str(err))
raise err
if self._show_skipped_ids:
context["ti"].xcom_push(key="skipped_ids", value=skipped_entities)
return self._put_manifest_data_by_reference(context, execution_context, valid_manifest_file)
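
# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the commit: the three by-reference
# operators chained so that each task pulls the dataset record id pushed by
# its predecessor. The import paths, DAG id, task ids and schedule are
# assumptions, not values taken from this repository.
if __name__ == "__main__":
    from datetime import datetime
    from airflow import DAG
    # Hypothetical module paths for the two sibling operators:
    from osdu_airflow.operators.ensure_manifest_integrity_by_reference import (
        EnsureManifestIntegrityOperatorByReference)
    from osdu_airflow.operators.process_manifest_r3_by_reference import (
        ProcessManifestOperatorR3ByReference)

    with DAG(dag_id="osdu_manifest_by_reference",
             start_date=datetime(2021, 1, 1),
             schedule_interval=None) as dag:
        validate = ValidateManifestSchemaOperatorByReference(
            task_id="validate_manifest_schema")
        integrity = EnsureManifestIntegrityOperatorByReference(
            task_id="ensure_manifest_integrity",
            previous_task_id="validate_manifest_schema")
        process = ProcessManifestOperatorR3ByReference(
            task_id="process_manifest",
            previous_task_id="ensure_manifest_integrity")
        validate >> integrity >> process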