diff --git a/osdu_airflow/operators/ensure_manifest_integrity_by_reference.py b/osdu_airflow/operators/ensure_manifest_integrity_by_reference.py index 5b7362c482bb252df8c052cb2eba4285bb1e381d..9cc8a8b183b2764ef53f5c8a31d5aee67658174d 100644 --- a/osdu_airflow/operators/ensure_manifest_integrity_by_reference.py +++ b/osdu_airflow/operators/ensure_manifest_integrity_by_reference.py @@ -17,6 +17,11 @@ import logging from airflow.models import BaseOperator, Variable +from osdu_api.clients.dataset.dataset_dms_client import DatasetDmsClient +from osdu_api.clients.dataset.dataset_registry_client import DatasetRegistryClient +from osdu_api.clients.entitlements.entitlements_client import EntitlementsClient +from osdu_api.clients.search.search_client import SearchClient +from osdu_api.configuration.config_manager import DefaultConfigManager from osdu_ingestion.libs.context import Context from osdu_ingestion.libs.refresh_token import AirflowTokenRefresher from osdu_ingestion.libs.validation.validate_file_source import FileSourceValidator @@ -53,27 +58,55 @@ class EnsureManifestIntegrityOperatorByReference(BaseOperator, ReceivingContextM :param context: Airflow context :type context: dict """ - payload_context = Context.populate( - context["dag_run"].conf["execution_context"]) + execution_context = context["dag_run"].conf["execution_context"] + payload_context = Context.populate(execution_context) token_refresher = AirflowTokenRefresher() file_source_validator = FileSourceValidator() + config_manager = DefaultConfigManager() + + search_client = SearchClient( + search_url=Variable.get("core__service__search__url", default_var=None), + token_refresher=token_refresher, + data_partition_id=payload_context.data_partition_id + ) manifest_integrity = ManifestIntegrity( - token_refresher, - file_source_validator, - payload_context, - self.whitelist_ref_patterns, + search_client=search_client, + file_source_validator=file_source_validator, + context=payload_context, + whitelist_ref_patterns=self.whitelist_ref_patterns, ) - execution_context = context["dag_run"].conf["execution_context"] + dataset_dms_client = DatasetDmsClient( + dataset_url=Variable.get("core__service__dataset__url", default_var=None), + config_manager=config_manager, + data_partition_id=payload_context.data_partition_id, + token_refresher=token_refresher, + logger=logger ) - manifest_data = self._get_manifest_data_by_reference(context=context, + dataset_reg_client = DatasetRegistryClient( + dataset_url=Variable.get("core__service__dataset__url", default_var=None), + config_manager=config_manager, + data_partition_id=payload_context.data_partition_id, + token_refresher=token_refresher, + logger=logger ) + + entitlements_client = EntitlementsClient( + entitlements_url=Variable.get("core__service__entitlements__url", default_var=None), + config_manager=config_manager, + data_partition_id=payload_context.data_partition_id, + token_refresher=token_refresher, + logger=logger ) + + manifest_data = self._get_manifest_data_by_reference(context=context, execution_context=execution_context, use_history=False, - config_manager=None, - data_partition_id=None, - token_refresher=token_refresher, + dataset_dms_client=dataset_dms_client, logger=logger) + previously_skipped_entities = self._get_previously_skipped_entities( context) logger.debug(f"Manifest data: {manifest_data}") @@ -90,7 +123,8 @@ class EnsureManifestIntegrityOperatorByReference(BaseOperator, ReceivingContextM
execution_context=execution_context, manifest=manifest, use_history=False, - config_manager=None, - data_partition_id=None, - token_refresher=token_refresher, + data_partition_id=payload_context.data_partition_id, + dataset_dms_client=dataset_dms_client, + dataset_reg_client=dataset_reg_client, + entitlements_client=entitlements_client, logger=logger) diff --git a/osdu_airflow/operators/mixins/ReceivingContextMixin.py b/osdu_airflow/operators/mixins/ReceivingContextMixin.py index 2c3333afc728f2349c4c55ded813efd818b75a82..4b197aba3f53f1d6421c5ae8f0a8db34a6de735c 100644 --- a/osdu_airflow/operators/mixins/ReceivingContextMixin.py +++ b/osdu_airflow/operators/mixins/ReceivingContextMixin.py @@ -17,11 +17,11 @@ import json import logging import re as re +from airflow.models import Variable from osdu_api.auth.authorization import TokenRefresher from osdu_api.clients.dataset.dataset_dms_client import DatasetDmsClient from osdu_api.clients.dataset.dataset_registry_client import DatasetRegistryClient from osdu_api.clients.entitlements.entitlements_client import EntitlementsClient -from osdu_api.configuration.base_config_manager import BaseConfigManager from osdu_api.configuration.config_manager import DefaultConfigManager from osdu_api.model.dataset.create_dataset_registries_request import CreateDatasetRegistriesRequest from osdu_api.model.http_method import HttpMethod @@ -68,25 +68,20 @@ class ReceivingContextMixin: previously_skipped_ids.extend(task_skipped_ids) return previously_skipped_ids - def _get_manifest_data_by_reference(self, context: dict, execution_context: dict, use_history:bool=False, - config_manager: BaseConfigManager = None, - data_partition_id = None, - token_refresher: TokenRefresher = None, + def _get_manifest_data_by_reference(self, context: dict, execution_context: dict, + dataset_dms_client: DatasetDmsClient, use_history:bool=False, logger = None) -> ManifestType: """ [Geosiris Developement] Get manifest from a datasetService. If use_history is set to True, the data is taken from the record_id history instead of using last task return value """ - if config_manager is None: - config_manager = DefaultConfigManager() - - if data_partition_id is None: - data_partition_id = config_manager.get('environment', 'data_partition_id') - if logger is None: logger = logging.getLogger() + use_new_dataset_service_endpoints = Variable.get("use_new_dataset_service_endpoints", default_var=True, deserialize_json=True) + + record_id = "" if use_history: record_id_list = context["ti"].xcom_pull(task_ids=self.previous_task_id, key="manifest_ref_ids") record_id = record_id_list[-1] # the last one is the most recent one @@ -98,24 +93,29 @@ class ReceivingContextMixin: logger.debug(f"#ReceivingContextMixin try to retrieve data from id : {record_id}.") - client_dms = DatasetDmsClient( config_manager=config_manager, - data_partition_id=data_partition_id, - token_refresher=token_refresher, - logger=logger) - retrieval = client_dms.get_retrieval_instructions(record_id=record_id) + if use_new_dataset_service_endpoints: + logger.debug(f"#ReceivingContextMixin you are using NEW dataset-service endpoints. "
+ + "If you want to use old dataset-service endpoints, set variable 'use_new_dataset_service_endpoints' to false" + + "in Airflow global variables") + retrieval = dataset_dms_client.retrieval_instructions(record_id=record_id) + else: + logger.debug(f"#ReceivingContextMixin your are using OLD dataset-service endpoints") + retrieval = dataset_dms_client.get_retrieval_instructions(record_id=record_id) retrieval_content_url = retrieval.json()["delivery"][0]["retrievalProperties"]["signedUrl"] - manifest_data = client_dms.make_request(method=HttpMethod.GET, url=retrieval_content_url).json() + manifest_data = dataset_dms_client.make_request(method=HttpMethod.GET, url=retrieval_content_url).json() if isinstance(manifest_data, str): return json.loads(manifest_data) else: return manifest_data - def _put_manifest_data_by_reference(self, context: dict, execution_context: dict, manifest, use_history:bool=False, - config_manager: BaseConfigManager = None, - data_partition_id = None, - token_refresher: TokenRefresher = None, + def _put_manifest_data_by_reference(self, context: dict, execution_context: dict, manifest, + data_partition_id: str, + dataset_dms_client: DatasetDmsClient, + dataset_reg_client: DatasetRegistryClient, + entitlements_client: EntitlementsClient=None, + use_history:bool=False, logger = None) -> str: """ [Geosiris Developement] @@ -142,9 +142,8 @@ class ReceivingContextMixin: if acl_data is None: logger.debug(f"Getting default value for acl, because not found in manifest. {type(manifest_dict)}, \n{manifest_dict}") acl_data = self._get_default_acl(execution_context=execution_context, - config_manager=config_manager, + entitlements_client=entitlements_client, data_partition_id=data_partition_id, - token_refresher=token_refresher, logger=logger) #### END ACL @@ -161,17 +160,15 @@ class ReceivingContextMixin: if legal_tags is None: logger.debug(f"Getting default value for legal, because not found in manifest. 
{type(manifest_dict)}, \n{manifest_dict}") - legal_tags = self._get_default_legaltags(execution_context=execution_context, - config_manager=config_manager, - data_partition_id=data_partition_id) + legal_tags = self._get_default_legaltags(execution_context=execution_context, data_partition_id=data_partition_id) #### END Legal tags record_id = self._put_file_on_dataset_service( file_content=manifest, # manifest_dict acl_data=acl_data, legal_tags=legal_tags, - config_manager=config_manager, + dataset_reg_client=dataset_reg_client, + dataset_dms_client=dataset_dms_client, data_partition_id=data_partition_id, - token_refresher=token_refresher, logger=logger) if use_history: @@ -187,32 +184,31 @@ class ReceivingContextMixin: def _put_file_on_dataset_service(self, file_content, acl_data: Acl, legal_tags: Legal, - config_manager: BaseConfigManager = None, - data_partition_id = None, - token_refresher: TokenRefresher = None, + dataset_reg_client: DatasetRegistryClient, + dataset_dms_client: DatasetDmsClient, + data_partition_id: str, logger = None) -> str: """ [Geosiris Developement] Store a file on the dataset-service """ - if config_manager is None: - config_manager = DefaultConfigManager() - - if data_partition_id is None: - data_partition_id = config_manager.get('environment', 'data_partition_id') - if logger is None: logger = logging.getLogger() - dataset_registry_url = config_manager.get('environment', 'dataset_registry_url') + use_new_dataset_service_endpoints = Variable.get("use_new_dataset_service_endpoints", default_var=True, deserialize_json=True) + + dataset_registry_url = Variable.get("core__service__dataset__url", default_var=None) match_domain = re.search(r'https?://([\w\.-]+).*', dataset_registry_url) dataset_registry_url_domain = match_domain.group(1) - client_dms = DatasetDmsClient( config_manager=config_manager, - data_partition_id=data_partition_id, - token_refresher=token_refresher, - logger=logger) - storage_instruction = client_dms.get_storage_instructions(kind_sub_type="dataset--File.Generic") + if use_new_dataset_service_endpoints: + logger.debug(f"#ReceivingContextMixin you are using NEW dataset-service endpoints. "
+ + "If you want to use old dataset-service endpoints, set variable 'use_new_dataset_service_endpoints' to false" + + "in Airflow global variables") + storage_instruction = dataset_dms_client.storage_instructions(kind_sub_type="dataset--File.Generic") + else: + logger.debug(f"#ReceivingContextMixin your are using OLD dataset-service endpoints") + storage_instruction = dataset_dms_client.get_storage_instructions(kind_sub_type="dataset--File.Generic") storage_location = storage_instruction.json()['storageLocation'] @@ -240,9 +236,8 @@ class ReceivingContextMixin: except KeyError as e: logger.debug("No 'signedUploadFileName' parameter found for storage location") - #### Uploading data - put_result = client_dms.make_request(method=HttpMethod.PUT, url=signed_url, data=file_content) + put_result = dataset_dms_client.make_request(method=HttpMethod.PUT, url=signed_url, data=file_content) record_list = [ Record( kind = f"{data_partition_id}:wks:dataset--File.Generic:1.0.0", @@ -252,7 +247,7 @@ class ReceivingContextMixin: "DatasetProperties": { "FileSourceInfo": { "FileSource": file_source, - "PreloadFilePath": f"{unsigned_url}{signed_upload_file_name}" + "PreLoadFilePath": f"{unsigned_url}{signed_upload_file_name}" } }, "ResourceSecurityClassification": f"{data_partition_id}:reference-data--ResourceSecurityClassification:RESTRICTED:", @@ -262,42 +257,28 @@ class ReceivingContextMixin: ancestry=RecordAncestry(parents=[])) ] - client_reg = DatasetRegistryClient( config_manager=config_manager, - data_partition_id=data_partition_id, - token_refresher=token_refresher, - logger=logger) + - registered_dataset = client_reg.register_dataset(CreateDatasetRegistriesRequest(dataset_registries=record_list)) + registered_dataset = dataset_reg_client.register_dataset(CreateDatasetRegistriesRequest(dataset_registries=record_list)) return registered_dataset.json()['datasetRegistries'][0]["id"] def _get_default_acl(self, execution_context: dict, - config_manager: BaseConfigManager = None, - data_partition_id = None, - token_refresher: TokenRefresher = None, - logger = None): + entitlements_client: EntitlementsClient, + data_partition_id = None, + logger = None): if "acl" in execution_context: # try to take the value from the context return execution_context['acl'] else: - if config_manager is None: - config_manager = DefaultConfigManager() - - if data_partition_id is None: - data_partition_id = config_manager.get('environment', 'data_partition_id') - if logger is None: logger = logging.getLogger() - dataset_registry_url = config_manager.get('environment', 'dataset_registry_url') + dataset_registry_url = Variable.get("core__service__dataset__url", default_var=None) match_domain = re.search(r'https?://([\w\.-]+).*', dataset_registry_url) dataset_registry_url_domain = match_domain.group(1) - - ent_client = EntitlementsClient(config_manager=config_manager, - data_partition_id=data_partition_id, - token_refresher=token_refresher, - logger=logger) - ent_response = ent_client.get_groups_for_user() + + ent_response = entitlements_client.get_groups_for_user() acl_domain = data_partition_id + "." 
+ dataset_registry_url_domain @@ -325,15 +306,10 @@ class ReceivingContextMixin: def _get_default_legaltags(self, execution_context: dict, - data_partition_id: str=None, - config_manager: BaseConfigManager = None): + data_partition_id: str): if "legal" in execution_context: # try to take the value from the context return execution_context['legal'] else: - if config_manager is None: - config_manager = DefaultConfigManager() - if data_partition_id is None: - data_partition_id = config_manager.get('environment', 'data_partition_id') return Legal(legaltags=[data_partition_id + "-demo-legaltag"], other_relevant_data_countries=["US"], status="compliant") diff --git a/osdu_airflow/operators/process_manifest_r3_by_reference.py b/osdu_airflow/operators/process_manifest_r3_by_reference.py index 39db0bcb5b9d7294f031973ab41c03637e975657..7f3545574ce47410af4107406e94d2056e9a9213 100644 --- a/osdu_airflow/operators/process_manifest_r3_by_reference.py +++ b/osdu_airflow/operators/process_manifest_r3_by_reference.py @@ -24,6 +24,11 @@ from typing import List, Tuple from airflow.models import BaseOperator, Variable from jsonschema import SchemaError +from osdu_api.clients.dataset.dataset_dms_client import DatasetDmsClient +from osdu_api.clients.schema.schema_client import SchemaClient +from osdu_api.clients.search.search_client import SearchClient +from osdu_api.clients.storage.record_client import RecordClient +from osdu_api.configuration.config_manager import DefaultConfigManager from osdu_ingestion.libs.constants import DATA_TYPES_WITH_SURROGATE_KEYS, SURROGATE_KEYS_PATHS from osdu_ingestion.libs.context import Context from osdu_ingestion.libs.exceptions import (EmptyManifestError, GenericManifestSchemaError, @@ -136,31 +141,52 @@ class ProcessManifestOperatorR3ByReference(BaseOperator, ReceivingContextMixin): file_source_validator = FileSourceValidator() source_file_checker = SourceFileChecker() + search_client = SearchClient( + search_url=Variable.get("core__service__search__url", default_var=None), + token_refresher=token_refresher, + data_partition_id=payload_context.data_partition_id + ) + record_client = RecordClient( + storage_url=Variable.get("core__service__storage__url", default_var=None), + token_refresher=token_refresher, + data_partition_id=payload_context.data_partition_id + ) + schema_client = SchemaClient( + schema_url=Variable.get("core__service__schema__url", default_var=None), + token_refresher=token_refresher, + data_partition_id=payload_context.data_partition_id + ) + + dataset_dms_client = DatasetDmsClient( + dataset_url=Variable.get("core__service__dataset__url", default_var=None), + config_manager=DefaultConfigManager(), + data_partition_id=payload_context.data_partition_id, + token_refresher=token_refresher, + logger=logger + ) + referential_integrity_validator = ManifestIntegrity( - token_refresher, + search_client, file_source_validator, payload_context ) manifest_processor = ManifestProcessor( + record_client=record_client, file_handler=file_handler, - token_refresher=token_refresher, context=payload_context, source_file_checker=source_file_checker, ) validator = SchemaValidator( - token_refresher, - payload_context, + schema_client=schema_client, + context=payload_context, data_types_with_surrogate_ids=DATA_TYPES_WITH_SURROGATE_KEYS, surrogate_key_fields_paths=SURROGATE_KEYS_PATHS ) single_manifest_processor = SingleManifestProcessor( - storage_url=self.storage_url, - payload_context=payload_context, referential_integrity_validator=referential_integrity_validator,
manifest_processor=manifest_processor, schema_validator=validator, - token_refresher=token_refresher, batch_save_enabled=self.batch_save_enabled, save_records_batch_size=self.batch_save_size ) @@ -168,9 +194,7 @@ class ProcessManifestOperatorR3ByReference(BaseOperator, ReceivingContextMixin): manifest_data = self._get_manifest_data_by_reference(context=context, execution_context=execution_context, use_history=False, - config_manager=None, - data_partition_id=None, - token_refresher=token_refresher, + dataset_dms_client=dataset_dms_client, logger=logger) logger.debug(f"Manifest data: {manifest_data}") diff --git a/osdu_airflow/operators/update_status_by_reference.py b/osdu_airflow/operators/update_status_by_reference.py index d63c6db8e3c9ca2dfc09e1f1c902622806849ddb..883b0b581f3fbe0aa4c7fb2b7882f9684fbe7b80 100644 --- a/osdu_airflow/operators/update_status_by_reference.py +++ b/osdu_airflow/operators/update_status_by_reference.py @@ -21,6 +21,11 @@ import logging from typing import Tuple from airflow.models import BaseOperator, Variable +from osdu_api.clients.dataset.dataset_dms_client import DatasetDmsClient +from osdu_api.clients.dataset.dataset_registry_client import DatasetRegistryClient +from osdu_api.clients.entitlements.entitlements_client import EntitlementsClient +from osdu_api.clients.ingestion_workflow.ingestion_workflow_client import IngestionWorkflowClient +from osdu_api.configuration.config_manager import DefaultConfigManager from osdu_ingestion.libs.context import Context from osdu_ingestion.libs.exceptions import PipelineFailedError from osdu_ingestion.libs.refresh_token import AirflowTokenRefresher @@ -32,7 +37,7 @@ from osdu_airflow.operators.mixins.ReceivingContextMixin import ReceivingContext logger = logging.getLogger() -class UpdateStatusOperatorByReference(BaseOperator): +class UpdateStatusOperatorByReference(BaseOperator, ReceivingContextMixin): """Operator to update status.""" ui_color = '#10ECAA' ui_fgcolor = '#000000' @@ -124,35 +129,66 @@ class UpdateStatusOperatorByReference(BaseOperator): workflow_name = conf["workflow_name"] run_id = conf["run_id"] status = self.status.value + + token_refresher = AirflowTokenRefresher() + + workflow_client = IngestionWorkflowClient( + ingestion_workflow_url=Variable.get("core__service__workflow__url", default_var=None), + token_refresher=token_refresher, + data_partition_id=payload_context.data_partition_id + ) status_updater = UpdateStatus( + workflow_client=workflow_client, workflow_name=workflow_name, - workflow_url=Variable.get("core__service__workflow__host"), workflow_id="", run_id=run_id, status=status, - token_refresher=AirflowTokenRefresher(), context=payload_context ) + status_updater.update_workflow_status() if self._show_skipped_ids: + config_manager = DefaultConfigManager() + + dataset_dms_client = DatasetDmsClient( + dataset_url=Variable.get("core__service__dataset__url", default_var=None), + config_manager=config_manager, + data_partition_id=payload_context.data_partition_id, + token_refresher=token_refresher, + logger=logger + ) + + dataset_reg_client = DatasetRegistryClient( + dataset_url=Variable.get("core__service__dataset__url", default_var=None), + config_manager=config_manager, + data_partition_id=payload_context.data_partition_id, + token_refresher=token_refresher, + logger=logger + ) + + entitlements_client = EntitlementsClient( + entitlements_url=Variable.get("core__service__entitlements__url", default_var=None), + config_manager=config_manager, + data_partition_id=payload_context.data_partition_id, + 
token_refresher=token_refresher, + logger=logger + ) + skipped_ids, saved_record_ids = self._create_skipped_report(context) context["ti"].xcom_push(key="skipped_ids", value=skipped_ids) context["ti"].xcom_push(key="saved_record_ids", value=saved_record_ids) - mixin = ReceivingContextMixin() - record_id = mixin._put_file_on_dataset_service( file_content=str(saved_record_ids), - acl_data=mixin._get_default_acl(execution_context=execution_context, - config_manager=None, - data_partition_id=None, - token_refresher=AirflowTokenRefresher(), - logger=logger), - legal_tags=mixin._get_default_legaltags(execution_context=execution_context, - config_manager=None, - data_partition_id=None), - config_manager=None, - data_partition_id=None, - token_refresher=AirflowTokenRefresher(), + record_id = self._put_file_on_dataset_service( file_content=str(saved_record_ids), + acl_data=self._get_default_acl(execution_context=execution_context, + entitlements_client=entitlements_client, + data_partition_id=payload_context.data_partition_id, + logger=logger), + legal_tags=self._get_default_legaltags(execution_context=execution_context, + data_partition_id=payload_context.data_partition_id), + data_partition_id=payload_context.data_partition_id, + dataset_reg_client=dataset_reg_client, + dataset_dms_client=dataset_dms_client, logger=logger) logger.error(f"#SKIPPED_IDS: Some ids in the manifest were skipped. You can find the report in the datasetService with this record id : {record_id}") diff --git a/osdu_airflow/operators/validate_manifest_schema_by_reference.py b/osdu_airflow/operators/validate_manifest_schema_by_reference.py index c9c5036d9e9f0cd1b5ab6d2c3e25f372f0ad9d64..4edf26b7511588862d38edbfe7ae4224caccb434 100644 --- a/osdu_airflow/operators/validate_manifest_schema_by_reference.py +++ b/osdu_airflow/operators/validate_manifest_schema_by_reference.py @@ -20,6 +20,11 @@ Validate Manifest against R3 schemas operator.
import logging from airflow.models import BaseOperator, Variable +from osdu_api.clients.dataset.dataset_dms_client import DatasetDmsClient +from osdu_api.clients.dataset.dataset_registry_client import DatasetRegistryClient +from osdu_api.clients.entitlements.entitlements_client import EntitlementsClient +from osdu_api.clients.schema.schema_client import SchemaClient +from osdu_api.configuration.config_manager import DefaultConfigManager from osdu_ingestion.libs.constants import DATA_TYPES_WITH_SURROGATE_KEYS, SURROGATE_KEYS_PATHS from osdu_ingestion.libs.context import Context from osdu_ingestion.libs.exceptions import EmptyManifestError, GenericManifestSchemaError @@ -60,24 +65,52 @@ class ValidateManifestSchemaOperatorByReference(BaseOperator, ReceivingContextMi execution_context = context["dag_run"].conf["execution_context"] payload_context = Context.populate(execution_context) token_refresher = AirflowTokenRefresher() + config_manager = DefaultConfigManager() logger.debug( f"DATA_TYPES_WITH_SURROGATE_KEYS: {DATA_TYPES_WITH_SURROGATE_KEYS}") logger.debug(f"SURROGATE_KEYS_PATHS: {SURROGATE_KEYS_PATHS}") + schema_client = SchemaClient( + schema_url=Variable.get("core__service__schema__url", default_var=None), + token_refresher=token_refresher, + data_partition_id=payload_context.data_partition_id + ) schema_validator = SchemaValidator( - token_refresher, - payload_context, + schema_client=schema_client, + context=payload_context, surrogate_key_fields_paths=SURROGATE_KEYS_PATHS, data_types_with_surrogate_ids=DATA_TYPES_WITH_SURROGATE_KEYS ) - manifest_data = self._get_manifest_data_by_reference(context=context, + dataset_dms_client = DatasetDmsClient( + dataset_url=Variable.get("core__service__dataset__url", default_var=None), + config_manager=config_manager, + data_partition_id=payload_context.data_partition_id, + token_refresher=token_refresher, + logger=logger + ) + + dataset_reg_client = DatasetRegistryClient( + dataset_url=Variable.get("core__service__dataset__url", default_var=None), + config_manager=config_manager, + data_partition_id=payload_context.data_partition_id, + token_refresher=token_refresher, + logger=logger + ) + + entitlements_client = EntitlementsClient( + entitlements_url=Variable.get("core__service__entitlements__url", default_var=None), + config_manager=config_manager, + data_partition_id=payload_context.data_partition_id, + token_refresher=token_refresher, + logger=logger + ) + + manifest_data = self._get_manifest_data_by_reference(context=context, execution_context=execution_context, use_history=True, # use the history because "check_payload_type" does not return the id - config_manager=None, - data_partition_id=None, - token_refresher=token_refresher, + dataset_dms_client=dataset_dms_client, logger=logger) if not manifest_data: @@ -99,7 +132,8 @@ class ValidateManifestSchemaOperatorByReference(BaseOperator, ReceivingContextMi execution_context=execution_context, manifest=valid_manifest_file, use_history=False, - config_manager=None, - data_partition_id=None, - token_refresher=token_refresher, + data_partition_id=payload_context.data_partition_id, + dataset_dms_client=dataset_dms_client, + dataset_reg_client=dataset_reg_client, + entitlements_client=entitlements_client, logger=logger) diff --git a/osdu_airflow/tests/airflow_var.json b/osdu_airflow/tests/airflow_var.json index 60b79819ff8f67d68a1a066792cd1e382072cd03..2e7b9404a878ed83209729e6ccb5b140ba8b47ab 100644 --- a/osdu_airflow/tests/airflow_var.json +++ b/osdu_airflow/tests/airflow_var.json @@ -1,6 +1,7 @@ { 
"core__service__storage__url": "https://test", "core__service__workflow__host": "https://test", + "core__service__entitlements__url": "https://test", "core__service__file__host": "htpps://test", "core__service__workflow__url": "https://test", "core__service__search__url": "https://test", @@ -10,6 +11,8 @@ "core__service__storage__host": "https://test", "core__service__file__url": "https://test", "core__config__dataload_config_path": "https://test", + "core__service__dataset__url": "https://test/test", "core__auth__access_token": "test", - "core__ingestion__batch_count": 3 + "core__ingestion__batch_count": 3, + "use_new_dataset_service_endpoints": "false" } diff --git a/osdu_airflow/tests/airflow_var.sh b/osdu_airflow/tests/airflow_var.sh new file mode 100644 index 0000000000000000000000000000000000000000..7a8dc5a17c9cd4bc7026f46b045990818ebe7b69 --- /dev/null +++ b/osdu_airflow/tests/airflow_var.sh @@ -0,0 +1,13 @@ +export AIRFLOW_VAR_CORE__SERVICE__STORAGE__URL="https://test" +export AIRFLOW_VAR_CORE__SERVICE__WORKFLOW__HOST="https://test" +export AIRFLOW_VAR_CORE__SERVICE__FILE__HOST="htpps://test" +export AIRFLOW_VAR_CORE__SERVICE__WORKFLOW__URL="https://test" +export AIRFLOW_VAR_CORE__SERVICE__SEARCH__URL="https://test" +export AIRFLOW_VAR_CORE__SERVICE__SEARCH__HOST="https://test" +export AIRFLOW_VAR_CORE__SERVICE__SCHEMA__URL="https://test" +export AIRFLOW_VAR_CORE__SERVICE__SCHEMA__HOST="https://test" +export AIRFLOW_VAR_CORE__SERVICE__STORAGE__HOST="https://test" +export AIRFLOW_VAR_CORE__SERVICE__FILE__URL="https://test" +export AIRFLOW_VAR_CORE__CONFIG__DATALOAD_CONFIG_PATH="https://test" +export AIRFLOW_VAR_CORE__AUTH__ACCESS_TOKEN="test" +export AIRFLOW_VAR_CORE__INGESTION__BATCH_COUNT=3 \ No newline at end of file diff --git a/osdu_airflow/tests/plugin-unit-tests/data/master/test_manifest.json b/osdu_airflow/tests/plugin-unit-tests/data/master/test_manifest.json new file mode 100644 index 0000000000000000000000000000000000000000..4bb1b28c5c015329fd9e1053db949b1586528631 --- /dev/null +++ b/osdu_airflow/tests/plugin-unit-tests/data/master/test_manifest.json @@ -0,0 +1,43 @@ +{ + "kind": "test:test:Manifest:1.0.0", + "ReferenceData": [], + "MasterData": [ + { + "id": "opendes:master-data/Wellbore:350112350400", + "kind": "opendes:osdu:TestMaster:0.3.0", + "groupType": "master-data", + "version": 1, + "acl": { + "owners": [ + "data.default.viewers@opendes.osdu-gcp.go3-nrg.projects.epam.com" + ], + "viewers": [ + "data.default.owners@opendes.osdu-gcp.go3-nrg.projects.epam.com" + ] + }, + "legal": { + "legaltags": [ + "opendes-demo-legaltag" + ], + "otherRelevantDataCountries": [ + "srn:opendes:master-data/GeoPoliticalEntity:USA:" + ], + "status": "srn:opendes:reference-data/LegalStatus:public:1111" + }, + "resourceHostRegionIDs": [ + "srn:opendes:reference-data/OSDURegion:US-EAST:" + ], + "resourceObjectCreationDateTime": "2020-10-16T11:14:45-05:00", + "resourceVersionCreationDateTime": "2020-10-16T11:14:45-05:00", + "resourceSecurityClassification": "srn:opendes:reference-data/ResourceSecurityClassification:public:", + "source": "srn:opendes:master-data/Organisation:Oklahoma Corporation Commission:", + "existenceKind": "srn:opendes:reference-data/ExistenceKind:Active:", + "licenseState": "srn:opendes:reference-data/LicenseState:Unlicensed:", + "data": { + "SequenceNumber": 1 + }, + "schema": "test:test:GenericMasterData:1.0.0" + } + ], + "Data": {} +} diff --git a/osdu_airflow/tests/plugin-unit-tests/file_paths.py b/osdu_airflow/tests/plugin-unit-tests/file_paths.py index 
e9d6aab8f22125dff678558fb6c0a7e7f84e8c23..ccd34f003ed3f5ec97e7b240ab0aea88b93f28e5 100644 --- a/osdu_airflow/tests/plugin-unit-tests/file_paths.py +++ b/osdu_airflow/tests/plugin-unit-tests/file_paths.py @@ -21,5 +21,6 @@ MANIFEST_GENERIC_SCHEMA_PATH = f"{DATA_PATH_PREFIX}/manifests/schema_Manifest.1. MANIFEST_WELLBORE_VALID_PATH = f"{DATA_PATH_PREFIX}/master/Wellbore.0.3.0.json" MANIFEST_BATCH_WELLBORE_VALID_PATH = f"{DATA_PATH_PREFIX}/master/batch_Wellbore.0.3.0.json" +MANIFEST_TEST_PATH = f"{DATA_PATH_PREFIX}/master/test_manifest.json" SEARCH_VALID_RESPONSE_PATH = f"{DATA_PATH_PREFIX}/other/SearchResponseValid.json" diff --git a/osdu_airflow/tests/plugin-unit-tests/mock_providers.py b/osdu_airflow/tests/plugin-unit-tests/mock_providers.py index 7690ff188e33f83fcc397fd29cfe42df85224159..2d26f3510d2cf3231cec70bee881441eea0d4b79 100644 --- a/osdu_airflow/tests/plugin-unit-tests/mock_providers.py +++ b/osdu_airflow/tests/plugin-unit-tests/mock_providers.py @@ -17,10 +17,11 @@ import io import logging from typing import Tuple + from osdu_api.providers.blob_storage import get_client from osdu_api.providers.credentials import get_credentials from osdu_api.providers.factory import ProvidersFactory -from osdu_api.providers.types import BlobStorageClient, BaseCredentials +from osdu_api.providers.types import BaseCredentials, BlobStorageClient logger = logging.getLogger(__name__) diff --git a/osdu_airflow/tests/plugin-unit-tests/mock_responses.py b/osdu_airflow/tests/plugin-unit-tests/mock_responses.py index 227f859b9445d3d04b008e7e50f266b0a6dacabb..89a092a31f28cab6a50a57f690194a915d0531b4 100644 --- a/osdu_airflow/tests/plugin-unit-tests/mock_responses.py +++ b/osdu_airflow/tests/plugin-unit-tests/mock_responses.py @@ -14,8 +14,9 @@ # limitations under the License. -import json import http +import json + import requests diff --git a/osdu_airflow/tests/plugin-unit-tests/test_operators_r3.py b/osdu_airflow/tests/plugin-unit-tests/test_operators_r3.py index 3c10760db9d6d06b62f2d430938b4955225bd3cd..8f3a39a81e87e3c993c0fe24afe2ae7b00b48f9e 100644 --- a/osdu_airflow/tests/plugin-unit-tests/test_operators_r3.py +++ b/osdu_airflow/tests/plugin-unit-tests/test_operators_r3.py @@ -13,18 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import mock_providers - import http import json from datetime import datetime from typing import ClassVar, TypeVar from unittest.mock import MagicMock +import mock_providers import pytest -from osdu_ingestion.libs.exceptions import PipelineFailedError -from osdu_ingestion.libs.handle_file import FileHandler - import requests from airflow import DAG from airflow.contrib.operators.kubernetes_pod_operator import KubernetesPodOperator @@ -32,18 +28,20 @@ from airflow.models import TaskInstance from file_paths import (MANIFEST_BATCH_WELLBORE_VALID_PATH, MANIFEST_GENERIC_SCHEMA_PATH, MANIFEST_WELLBORE_VALID_PATH) from mock_responses import MockWorkflowResponse +from osdu_api.clients.search.search_client import SearchClient +from osdu_api.clients.storage.record_client import RecordClient +from osdu_ingestion.libs.exceptions import PipelineFailedError +from osdu_ingestion.libs.handle_file import FileHandler +from osdu_ingestion.libs.segy_conversion_metadata.headers_byte_locations import HeadersByteLocations +from osdu_ingestion.libs.segy_conversion_metadata.open_vds import OpenVDSMetadata + from osdu_airflow.operators.ensure_manifest_integrity import EnsureManifestIntegrityOperator from osdu_airflow.operators.mixins.ReceivingContextMixin import ReceivingContextMixin from osdu_airflow.operators.process_manifest_r3 import (ManifestProcessor, ProcessManifestOperatorR3, SchemaValidator) -from osdu_airflow.operators.segy_open_vds_conversion import KubernetesPodSegyToOpenVDSOperator +from osdu_airflow.operators.segy_open_vds_conversion import KubernetesPodSegyToOpenVDSOperator from osdu_airflow.operators.update_status import UpdateStatusOperator from osdu_airflow.operators.validate_manifest_schema import ValidateManifestSchemaOperator -from osdu_api.clients.storage.record_client import RecordClient -from osdu_api.clients.search.search_client import SearchClient -from osdu_ingestion.libs.segy_conversion_metadata.open_vds import OpenVDSMetadata -from osdu_ingestion.libs.segy_conversion_metadata.headers_byte_locations import HeadersByteLocations - CustomOperator = TypeVar("CustomOperator") diff --git a/osdu_airflow/tests/plugin-unit-tests/test_operators_r3_manifest_by_reference.py b/osdu_airflow/tests/plugin-unit-tests/test_operators_r3_manifest_by_reference.py new file mode 100644 index 0000000000000000000000000000000000000000..c95bf02bd8cdfc7bba9b14cc37d056589abe359e --- /dev/null +++ b/osdu_airflow/tests/plugin-unit-tests/test_operators_r3_manifest_by_reference.py @@ -0,0 +1,200 @@ +# Copyright 2020 Google LLC +# Copyright 2020 EPAM Systems +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import http +import json +from datetime import datetime +from typing import ClassVar, TypeVar +from unittest.mock import Mock + +import mock_providers +import pytest +import requests +from airflow import DAG +from airflow.models import TaskInstance +from file_paths import (MANIFEST_BATCH_WELLBORE_VALID_PATH, MANIFEST_GENERIC_SCHEMA_PATH, + MANIFEST_TEST_PATH, MANIFEST_WELLBORE_VALID_PATH) +from mock_responses import MockWorkflowResponse +from osdu_ingestion.libs.exceptions import PipelineFailedError +from osdu_ingestion.libs.handle_file import FileHandler + +import osdu_airflow.operators.mixins.ReceivingContextMixin as receiving_context +from osdu_airflow.operators.ensure_manifest_integrity_by_reference import \ + EnsureManifestIntegrityOperatorByReference +from osdu_airflow.operators.mixins.ReceivingContextMixin import ReceivingContextMixin +from osdu_airflow.operators.process_manifest_r3_by_reference import ( + ManifestProcessor, ProcessManifestOperatorR3ByReference, SchemaValidator) +from osdu_airflow.operators.update_status_by_reference import UpdateStatusOperatorByReference +from osdu_airflow.operators.validate_manifest_schema_by_reference import \ + ValidateManifestSchemaOperatorByReference + +CustomOperator = TypeVar("CustomOperator") + + +class MockDagRun: + def __init__(self, conf): + self.conf = conf + + +class MockStorageResponse(requests.Response): + + def json(self, **kwargs): + return {"recordIds": ["test"]} + + +class TestManifestByReferenceOperators(object): + + @staticmethod + def _read_manifest(): + with open(MANIFEST_TEST_PATH) as f: + return f.read() + + def _make_services_calls_mocked(self): + mock_retrieval_response = Mock() + mock_retrieval_response.json = Mock( + side_effect=lambda *args, **kwargs: {"delivery": [{"retrievalProperties": {"signedUrl": "test"}}]}) + receiving_context.DatasetDmsClient.get_retrieval_instructions = Mock(side_effect=lambda *args, **kwargs: mock_retrieval_response) + + mock_storage_instructions = Mock() + mock_storage_instructions.json = Mock( + side_effect=lambda *args, **kwargs: { + "storageLocation": { + "signedUrl": "string", + "fileSource": "string", + } + } + ) + receiving_context.DatasetDmsClient.get_storage_instructions = Mock(side_effect=lambda *args, **kwargs: mock_storage_instructions) + + mock_make_request = Mock() + mock_make_request.json = Mock(side_effect=lambda *args, **kwargs: self._read_manifest()) + receiving_context.DatasetDmsClient.make_request = Mock(side_effect=lambda *args, **kwargs: mock_make_request) + + mock_register_dataset = Mock() + mock_register_dataset.json = Mock(side_effect=lambda *args, **kwargs: {'datasetRegistries': [{"id": "test"}]}) + receiving_context.DatasetRegistryClient.register_dataset = Mock(side_effect=lambda *args, **kwargs: mock_register_dataset) + + mock_get_groups_for_user = Mock() + mock_get_groups_for_user.json = Mock( + side_effect=lambda *args, **kwargs: {"groups": []}) + receiving_context.EntitlementsClient.get_groups_for_user = Mock(side_effect=lambda *args, **kwargs: mock_get_groups_for_user) + + def _create_batch_task(self, operator: ClassVar[CustomOperator]) -> (CustomOperator, dict): + self._make_services_calls_mocked() + with open(MANIFEST_BATCH_WELLBORE_VALID_PATH) as f: + conf = json.load(f) + dag = DAG(dag_id='batch_osdu_ingest', start_date=datetime.now()) + task: CustomOperator = operator(dag=dag, task_id='anytask') + ti = TaskInstance(task=task, execution_date=datetime.now()) + + ti.xcom_pull = Mock(side_effect=lambda *args, **kwargs: "test") + context = 
ti.get_template_context() + context["dag_run"] = MockDagRun(conf) + return task, context + + def _create_task(self, operator: ClassVar[CustomOperator]) -> (CustomOperator, dict): + self._make_services_calls_mocked() + with open(MANIFEST_WELLBORE_VALID_PATH) as f: + conf = json.load(f) + dag = DAG(dag_id='Osdu_ingest', start_date=datetime.now()) + task: CustomOperator = operator(dag=dag, task_id='anytask') + ti = TaskInstance(task=task, execution_date=datetime.now()) + + ti.xcom_pull = Mock(side_effect=lambda *args, **kwargs: "test") + context = ti.get_template_context() + context["dag_run"] = MockDagRun(conf) + return task, context + + def test_process_manifest_r3_operator(self, monkeypatch): + + def _get_common_schema(*args, **kwargs): + with open(MANIFEST_GENERIC_SCHEMA_PATH) as f: + manifest_schema = json.load(f) + return manifest_schema + + + monkeypatch.setattr(SchemaValidator, "get_schema", _get_common_schema) + monkeypatch.setattr(SchemaValidator, "_validate_against_schema", lambda *args, **kwargs: None) + monkeypatch.setattr(SchemaValidator, "validate_manifest", lambda obj, entities: entities) + monkeypatch.setattr(ManifestProcessor, "save_record_to_storage", + lambda obj, headers, request_data: MockStorageResponse()) + monkeypatch.setattr(FileHandler, "upload_file", + lambda *args, **kwargs: "test") + + task, context = self._create_task(ProcessManifestOperatorR3ByReference) + task.pre_execute(context) + task.execute(context) + + def _test_update_status_operator(self, monkeypatch, status: UpdateStatusOperatorByReference.prev_ti_state): + monkeypatch.setattr(UpdateStatusOperatorByReference, "get_previous_ti_statuses", + lambda obj, context: status) + monkeypatch.setattr(requests, "put", lambda *args, **kwargs: MockWorkflowResponse( + status_code=http.HTTPStatus.OK, json="test")) + + task, context = self._create_task(UpdateStatusOperatorByReference) + task.pre_execute(context) + task.execute(context) + + @pytest.mark.parametrize( + "status", + [ + pytest.param( + UpdateStatusOperatorByReference.prev_ti_state.NONE + ), + pytest.param( + UpdateStatusOperatorByReference.prev_ti_state.SUCCESS + ) + ] + ) + def test_update_status_operator(self, monkeypatch, status): + self._test_update_status_operator(monkeypatch, status) + + @pytest.mark.parametrize( + "status", + [ + pytest.param( + UpdateStatusOperatorByReference.prev_ti_state.FAILED + ) + ] + ) + def test_update_status_operator_failed(self, monkeypatch, status): + """ + Test if operator raises PipelineFailedError if any previous task failed.
+ """ + with pytest.raises(PipelineFailedError): + self._test_update_status_operator(monkeypatch, status) + + def test_validate_schema_operator(self, monkeypatch): + + def _get_common_schema(*args, **kwargs): + with open(MANIFEST_GENERIC_SCHEMA_PATH) as f: + manifest_schema = json.load(f) + return manifest_schema + + monkeypatch.setattr(SchemaValidator, "get_schema", _get_common_schema) + monkeypatch.setattr(SchemaValidator, "_validate_against_schema", lambda *args, **kwargs: None) + monkeypatch.setattr(SchemaValidator, "validate_manifest", lambda obj, entities: entities) + task, context = self._create_task(ValidateManifestSchemaOperatorByReference) + task.pre_execute(context) + task.execute(context) + + def test_ensure_manifest_integrity(self, monkeypatch): + monkeypatch.setattr(FileHandler, "upload_file", + lambda *args, **kwargs: "test") + monkeypatch.setattr(ReceivingContextMixin, "_get_previously_skipped_entities", + lambda *args, **kwargs: []) + task, context = self._create_task(EnsureManifestIntegrityOperatorByReference) + task.pre_execute(context) + task.execute(context)