Commit fab381c4 authored by Siarhei Khaletski (EPAM)

Merge branch 'GONRG-3109_move_common_logic' into 'master'

GONRG-3109: Move common logic to osdu-airflow-lib

See merge request !76
parents d217212e 23dcb378
1 merge request: !76 GONRG-3109: Move common logic to osdu-airflow-lib
Pipeline #65237 passed
Showing 8 additions and 1173 deletions
@@ -66,7 +66,7 @@ unit_tests:
image: eu.gcr.io/osdu-cicd-epam/airflow-python-dags/airflow-python-dags:latest
script:
- chmod +x tests/unit_tests.sh
- tests/./unit_tests.sh || EXIT_CODE=$?
# - tests/./unit_tests.sh || EXIT_CODE=$? # TODO: unit tests moved to osdu-airflow-lib, need to remove the `unit_tests` CI step later
- exit ${EXIT_CODE}
# TODO: Think about how rsync must look. At the moment it looks messy.
......
@@ -13,7 +13,6 @@
* * [Ingestion](#ingestion)
* * [OSDU Python SDK](#osdu-python-sdk)
* [Testing](#testing)
* * [Running Unit Tests](#running-unit-tests)
* * [Running E2E Tests](#running-e2e-tests)
* [Logging](#logging)
* * [Logging Configuration](#logging-configuration)
@@ -129,15 +128,6 @@ If variable defines URL to internal services it should have suffix which show th
| core__ingestion__batch_save_size | Size of the batch of entities to save in Storage Service|
## Testing
### Running Unit Tests
~~~
tests/./set_airflow_env.sh
~~~
~~~
chmod +x tests/test_dags.sh && tests/./test_dags.sh
~~~
### Running E2E Tests
~~~
tests/./set_airflow_env.sh
......
@@ -20,9 +20,9 @@ from datetime import timedelta
import airflow
from airflow import DAG
from osdu_airflow.backward_compatibility.default_args import update_default_args
from osdu_manifest.operators.deprecated.update_status import UpdateStatusOperator
from osdu_manifest.operators.process_manifest_r2 import ProcessManifestOperatorR2
from osdu_manifest.operators.search_record_id import SearchRecordIdOperator
from osdu_airflow.operators.deprecated.update_status import UpdateStatusOperator
from osdu_airflow.operators.process_manifest_r2 import ProcessManifestOperatorR2
from osdu_airflow.operators.search_record_id import SearchRecordIdOperator
default_args = {
"start_date": airflow.utils.dates.days_ago(0),
......
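# Illustrative sketch (an assumption, not part of this diff): with the imports above
# switched from osdu_manifest to osdu_airflow, the R2 ingestion DAG can be wired the
# same way as before. The dag_id and task ids are placeholders, and update_default_args
# is assumed to return an updated copy of the default_args dict.
with DAG("example_osdu_ingest_r2", default_args=update_default_args(default_args),
         schedule_interval=None) as dag:
    update_status_running = UpdateStatusOperator(task_id="update_status_running_task")
    process_manifest = ProcessManifestOperatorR2(task_id="process_manifest_task")
    search_record_ids = SearchRecordIdOperator(task_id="search_record_ids_task")
    update_status_finished = UpdateStatusOperator(task_id="update_status_finished_task")
    update_status_running >> process_manifest >> search_record_ids >> update_status_finished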
@@ -23,11 +23,11 @@ from airflow.models import Variable
from airflow.operators.dummy_operator import DummyOperator
from airflow.operators.python_operator import BranchPythonOperator
from osdu_airflow.backward_compatibility.default_args import update_default_args
from osdu_airflow.operators.ensure_manifest_integrity import EnsureManifestIntegrityOperator
from osdu_airflow.operators.process_manifest_r3 import ProcessManifestOperatorR3
from osdu_airflow.operators.update_status import UpdateStatusOperator
from osdu_airflow.operators.validate_manifest_schema import ValidateManifestSchemaOperator
from osdu_api.libs.exceptions import NotOSDUSchemaFormatError
from osdu_manifest.operators.ensure_manifest_integrity import EnsureManifestIntegrityOperator
from osdu_manifest.operators.process_manifest_r3 import ProcessManifestOperatorR3
from osdu_manifest.operators.update_status import UpdateStatusOperator
from osdu_manifest.operators.validate_manifest_schema import ValidateManifestSchemaOperator
BATCH_NUMBER = int(Variable.get("core__ingestion__batch_count", "3"))
PROCESS_SINGLE_MANIFEST_FILE = "process_single_manifest_file_task"
......
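# Illustrative sketch (an assumption, not this DAG's actual layout): the
# core__ingestion__batch_count variable suggests BATCH_NUMBER parallel
# ProcessManifestOperatorR3 tasks, each given its own batch_number so it processes
# only its slice of the manifest list. DAG boilerplate is omitted; `dag` is assumed
# to be the enclosing DAG object, and the task ids are placeholders.
validate_task = ValidateManifestSchemaOperator(
    task_id="validate_manifest_schema_task", dag=dag)
process_tasks = [
    ProcessManifestOperatorR3(
        task_id=f"process_manifest_task_{batch}",
        previous_task_id="validate_manifest_schema_task",
        batch_number=batch,
        dag=dag,
    )
    for batch in range(1, BATCH_NUMBER + 1)
]
validate_task >> process_tasks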
# Copyright 2020 Google LLC
# Copyright 2020 EPAM Systems
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .http_hooks import *
# Copyright 2020 Google LLC
# Copyright 2020 EPAM Systems
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Http Hooks."""
from airflow.hooks.http_hook import HttpHook
workflow_hook = HttpHook(http_conn_id='workflow', method="POST")
search_http_hook = HttpHook(http_conn_id='search', method="POST")
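# Illustrative usage (an assumption, not part of this module): an operator could POST
# to the Workflow service through this hook. HttpHook.run(endpoint, data, headers) is
# Airflow's API; the endpoint, payload and partition id below are placeholders only.
import json

def update_workflow_status(workflow_id: str, status: str):
    # Placeholder payload and endpoint; the real contract lives with the Workflow service.
    return workflow_hook.run(
        endpoint="/v1/workflow",
        data=json.dumps({"WorkflowID": workflow_id, "Status": status}),
        headers={"Content-Type": "application/json", "data-partition-id": "opendes"},
    )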
# Copyright 2020 Google LLC
# Copyright 2020 EPAM Systems
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright 2020 Google LLC
# Copyright 2020 EPAM Systems
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Update Status operator."""
import copy
import enum
import logging
from airflow.models import BaseOperator, Variable
from osdu_api.libs.context import Context
from osdu_api.libs.exceptions import PipelineFailedError
from osdu_api.libs.refresh_token import AirflowTokenRefresher
from osdu_api.libs.update_status import UpdateStatus
logger = logging.getLogger()
class UpdateStatusOperator(BaseOperator):
"""Operator to update status."""
ui_color = '#10ECAA'
ui_fgcolor = '#000000'
class prev_ti_state(enum.Enum):
NONE = "running"
SUCCESS = "finished"
FAILED = "failed"
def get_previous_ti_statuses(self, context: dict) -> enum.Enum:
"""Get status of previous tasks' executions.
Return corresponding enum value.
:param context: Airflow context
:type context: dict
:return: Previous status
:rtype: enum.Enum
"""
dagrun = context['ti'].get_dagrun()
failed_ti = dagrun.get_task_instances(state='failed')
success_ti = dagrun.get_task_instances(state='success')
if not failed_ti and not success_ti:  # there are no previous tasks, so none can have failed
logger.info("There are no tasks before this one, so its status is RUNNING")
return self.prev_ti_state.NONE
if failed_ti:
logger.info("There are failed tasks before this one, so its status is FAILED")
return self.prev_ti_state.FAILED
logger.info("There are successful tasks before this one, so its status is SUCCESS")
return self.prev_ti_state.SUCCESS
def pre_execute(self, context: dict):
self.status = self.get_previous_ti_statuses(context)
def execute(self, context: dict):
"""Execute update workflow status.
If the status is assumed to be FINISHED, we check whether the records
are searchable or not.
If they are, the status is updated to FINISHED, otherwise to FAILED.
:param context: Airflow context
:type context: dict
:raises PipelineFailedError: If any previous task failed
"""
conf = copy.deepcopy(context["dag_run"].conf)
logger.debug(f"Got conf {conf}.")
if "Payload" in conf:
payload_context = Context.populate(conf)
else:
payload_context = Context(data_partition_id=conf["data-partition-id"],
app_key=conf.get("AppKey", ""))
workflow_id = conf["WorkflowID"]
status = self.status.value
status_updater = UpdateStatus(
workflow_name="",
workflow_url=Variable.get("core__service__workflow__url"),
workflow_id=workflow_id,
run_id="",
status=status,
token_refresher=AirflowTokenRefresher(),
context=payload_context
)
status_updater.update_workflow_status()
if self.status is self.prev_ti_state.FAILED:
raise PipelineFailedError("Dag failed")
# Copyright 2021 Google LLC
# Copyright 2021 EPAM Systems
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""R3 Validate reference Manifest operator."""
import logging
from airflow.models import BaseOperator, Variable
from osdu_airflow.backward_compatibility.airflow_utils import apply_defaults
from osdu_api.libs.context import Context
from osdu_api.libs.refresh_token import AirflowTokenRefresher
from osdu_api.libs.validation.validate_file_source import FileSourceValidator
from osdu_api.libs.validation.validate_referential_integrity import ManifestIntegrity
from osdu_manifest.operators.mixins.ReceivingContextMixin import ReceivingContextMixin
logger = logging.getLogger()
class EnsureManifestIntegrityOperator(BaseOperator, ReceivingContextMixin):
"""Operator to validate references inside an R3 manifest and remove invalid entities."""
ui_color = '#dad5ff'
ui_fgcolor = '#000000'
@apply_defaults
def __init__(self,
previous_task_id: str=None,
*args, **kwargs):
"""Init base operator and obtain base urls from Airflow Variables."""
super().__init__(*args, **kwargs)
self.search_url = Variable.get('core__service__search__url')
self.whitelist_ref_patterns = Variable.get('core__config__reference_patterns_whitelist', default_var=None)
self.previous_task_id = previous_task_id
self._show_skipped_ids = Variable.get(
'core__config__show_skipped_ids', default_var=False
)
def execute(self, context: dict):
"""Execute manifest validation then process it.
:param context: Airflow context
:type context: dict
"""
payload_context = Context.populate(context["dag_run"].conf["execution_context"])
token_refresher = AirflowTokenRefresher()
file_source_validator = FileSourceValidator()
manifest_integrity = ManifestIntegrity(
self.search_url,
token_refresher,
file_source_validator,
payload_context,
self.whitelist_ref_patterns,
)
execution_context = context["dag_run"].conf["execution_context"]
manifest_data = self._get_manifest_data(context, execution_context)
previously_skipped_entities = self._get_previously_skipped_entities(context)
logger.debug(f"Manifest data: {manifest_data}")
manifest, skipped_ids = manifest_integrity.ensure_integrity(
manifest_data,
previously_skipped_entities
)
logger.debug(f"Valid manifest data: {manifest}")
if self._show_skipped_ids:
context["ti"].xcom_push(key="skipped_ids", value=skipped_ids)
return {"manifest": manifest}
# Copyright 2021 Google LLC
# Copyright 2021 EPAM Systems
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from osdu_api.libs.types import ManifestType
class ReceivingContextMixin:
"""Mixin for receiving a manifest file from XCom when the current operator is not the first in the pipeline."""
def _get_manifest_data(self, context: dict, execution_context: dict) -> ManifestType:
"""
Receive the manifest file. If previous_task_id is not None, get the manifest from XCom;
otherwise take it from the execution context.
"""
if self.previous_task_id:
previous_task_value = context["ti"].xcom_pull(task_ids=self.previous_task_id,
key="return_value")
if previous_task_value:
manifest_data = previous_task_value["manifest"]
else:
manifest_data = execution_context["manifest"]
else:
manifest_data = execution_context["manifest"]
return manifest_data
def _get_previously_skipped_entities(self, context: dict) -> list:
"""
Receive skipped entities from previous tasks.
"""
previously_skipped_ids = []
dagrun = context['ti'].get_dagrun()
task_instances = dagrun.get_task_instances()
for task in task_instances:
task_skipped_ids = context["ti"].xcom_pull(key="skipped_ids", task_ids=task.task_id)
if task_skipped_ids:
previously_skipped_ids.extend(task_skipped_ids)
return previously_skipped_ids
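# Minimal sketch (assumption) of an operator built on this mixin: previous_task_id
# controls whether the manifest comes from the upstream task's XCom or from the
# original execution context. The class name and fallback behaviour shown are
# illustrative only.
from airflow.models import BaseOperator


class ExampleManifestOperator(BaseOperator, ReceivingContextMixin):
    def __init__(self, previous_task_id: str = None, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.previous_task_id = previous_task_id

    def execute(self, context: dict):
        execution_context = context["dag_run"].conf["execution_context"]
        # Falls back to execution_context["manifest"] when previous_task_id is None.
        manifest = self._get_manifest_data(context, execution_context)
        # Returning a dict exposes the manifest to downstream tasks via XCom.
        return {"manifest": manifest}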
# Copyright 2021 Google LLC
# Copyright 2021 EPAM Systems
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright 2020 Google LLC
# Copyright 2020 EPAM Systems
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""R2 Process Manifest operator."""
import configparser
import enum
import json
import logging
import re
import sys
import time
import uuid
from collections import Counter
from typing import Tuple
from urllib.error import HTTPError
import requests
import tenacity
from airflow.models import BaseOperator, Variable
from osdu_api.libs.auth.authorization import authorize
from osdu_api.libs.refresh_token import AirflowTokenRefresher
config = configparser.RawConfigParser()
config.read(Variable.get("core__config__dataload_config_path"))
DEFAULT_TENANT = config.get("DEFAULTS", "tenant")
DEFAULT_SOURCE = config.get("DEFAULTS", "authority")
DEFAULT_VERSION = config.get("DEFAULTS", "kind_version")
RETRIES = 3
TIMEOUT = 1
# Set up base logger
handler = logging.StreamHandler(sys.stdout)
handler.setFormatter(
logging.Formatter("%(asctime)s [%(name)-14.14s] [%(levelname)-7.7s] %(message)s"))
logger = logging.getLogger("Dataload")
logger.setLevel(logging.INFO)
logger.addHandler(handler)
# Some constants, used by script
SEARCH_OK_RESPONSE_CODES = [200]
DATA_LOAD_OK_RESPONSE_CODES = [201]
NOT_FOUND_RESPONSE_CODES = [404]
BAD_TOKEN_RESPONSE_CODES = [400, 401, 403, 500]
class FileType(enum.Enum):
MANIFEST = enum.auto()
WORKPRODUCT = enum.auto()
def dataload(**kwargs):
data_conf = kwargs['dag_run'].conf
conf_payload = kwargs["dag_run"].conf["Payload"]
loaded_conf = {
"acl": conf_payload["acl"],
"legal_tag": conf_payload["legal"],
"data_object": data_conf
}
return loaded_conf, conf_payload
def create_headers(conf_payload):
"""Create header.
:param conf_payload: config payload
:return: headers
"""
partition_id = conf_payload["data-partition-id"]
app_key = conf_payload["AppKey"]
headers = {
'Content-type': 'application/json',
'data-partition-id': partition_id,
'AppKey': app_key
}
return headers
def generate_id(type_id):
"""Generate resource ID.
:param type_id: resource type ID
:return: resource ID
"""
return "{0}{1}:".format(type_id.replace("type:", ""), re.sub(r"\D", "", str(uuid.uuid4())))
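# Worked example (the digits come from a random uuid4 and the SRN value is illustrative):
# generate_id("srn:type:file/las2:") -> "srn:file/las2:434064998475093:"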
def determine_data_type(raw_resource_type_id):
"""Determine resource type ID.
:param raw_resource_type_id: raw resource type ID from manifest file
:return: short resource type ID
"""
return raw_resource_type_id.split("/")[-1].replace(":", "") \
if raw_resource_type_id is not None else None
# TODO: add comments to functions that implement actions in this function
def process_file_items(loaded_conf, conf_payload) -> Tuple[list, list]:
"""Process files items.
:param loaded_conf: loaded configuration
:param conf_payload: configuration payload
:return: list of file records and list of their ids
"""
file_ids = []
file_list = []
data_object = loaded_conf.get("data_object")
acl = loaded_conf.get("acl")
legal_tag = loaded_conf.get("legal_tag")
for file in data_object["Files"]:
file["ResourceID"] = generate_id(file["ResourceTypeID"])
file_ids.append(file["ResourceID"])
file_list.append(
(
populate_request_body(file, acl, legal_tag, "file", conf_payload),
"File"
)
)
return file_list, file_ids
def process_wpc_items(loaded_conf, product_type, file_ids, conf_payload):
"""Process WorkProductComponents items.
:param loaded_conf: loaded configuration
:param product_type: product type
:param file_ids: list of file ids
:param conf_payload: configuration payload
:return: list of workproductcomponents records and list of their ids
"""
wpc_ids = []
wpc_list = []
data_object = loaded_conf.get("data_object")
acl = loaded_conf.get("acl")
legal_tag = loaded_conf.get("legal_tag")
for wpc in data_object["WorkProductComponents"]:
wpc["ResourceID"] = generate_id(wpc["ResourceTypeID"])
wpc_ids.append(wpc["ResourceID"])
wpc["Data"]["GroupTypeProperties"]["Files"] = file_ids
wpc_list.append(
(
populate_request_body(wpc, acl, legal_tag, product_type + "_wpc", conf_payload),
product_type + "_wpc"
)
)
return wpc_list, wpc_ids
def process_wp_item(loaded_conf, product_type, wpc_ids, conf_payload) -> list:
"""Process WorkProduct item.
:param loaded_conf: loaded configuration
:param product_type: product type
:param wpc_ids: work product component ids
:param conf_payload: configuration payload
:return: work product record
"""
data_object = loaded_conf.get("data_object")
acl = loaded_conf.get("acl")
legal_tag = loaded_conf.get("legal_tag")
work_product = data_object["WorkProduct"]
work_product["ResourceID"] = generate_id(work_product["ResourceTypeID"])
work_product["Data"]["GroupTypeProperties"]["Components"] = wpc_ids
work_product = [
(
populate_request_body(work_product, acl, legal_tag, product_type + "_wp", conf_payload),
product_type + "_wp"
)
]
return work_product
def validate_file_type(file_type, data_object):
"""Validate file type.
:param file_type: file type
:param data_object: file record
"""
if not file_type:
logger.error(f"Error with file {data_object}. Type could not be specified.")
sys.exit(2)
def validate_file(loaded_conf) -> Tuple[FileType, str]:
"""Validate file.
:param loaded_conf: loaded configuration
:return: file type and product type
"""
data_object = loaded_conf.get("data_object")
if not data_object:
logger.error(f"Error with file {data_object}. It is empty.")
sys.exit(2)
elif "Manifest" in data_object and "ResourceTypeID" in data_object.get("Manifest"):
product_type = determine_data_type(data_object["Manifest"].get("ResourceTypeID"))
validate_file_type(product_type, data_object)
return (FileType.MANIFEST, product_type)
elif "WorkProduct" in data_object and "ResourceTypeID" in data_object.get("WorkProduct"):
product_type = determine_data_type(data_object["WorkProduct"].get("ResourceTypeID"))
validate_file_type(product_type, data_object)
if product_type.lower() == "workproduct" and \
data_object.get("WorkProductComponents") and \
len(data_object["WorkProductComponents"]) >= 1:
product_type = determine_data_type(
data_object["WorkProductComponents"][0].get("ResourceTypeID"))
validate_file_type(product_type, data_object)
return (FileType.WORKPRODUCT, product_type)
else:
logger.error(
f"Error with file {data_object}. It doesn't contain a Manifest or WorkProduct with a ResourceTypeID.")
sys.exit(2)
def create_kind(data_kind, conf_payload):
"""Create kind.
:param data_kind: data kind
:param conf_payload: configuration payload
:return: kind
"""
partition_id = conf_payload.get("data-partition-id", DEFAULT_TENANT)
source = conf_payload.get("authority", DEFAULT_SOURCE)
version = conf_payload.get("kind_version", DEFAULT_VERSION)
kind_init = config.get("KINDS_INITIAL", f"{data_kind.lower()}_kind")
kind = f"{partition_id}:{source}:{kind_init}:{version}"
return kind
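# Worked example (values are illustrative, not taken from a real dataload config):
# with data-partition-id "opendes", authority "osdu", kind_version "1.0.0" and a
# KINDS_INITIAL entry file_kind = "wks:file", create_kind("file", conf_payload)
# returns "opendes:osdu:wks:file:1.0.0".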
def populate_request_body(data, acl, legal_tag, data_type, conf_payload):
"""Populate request body according to the API specification.
:param data: item data from manifest files
:param acl: ACL to apply to the record
:param legal_tag: legal tags to apply to the record
:param data_type: resource type ID
:param conf_payload: configuration payload
:return: populated request
:rtype: dict
"""
request = {"kind": create_kind(data_type, conf_payload),
"legal": {
"legaltags": [],
"otherRelevantDataCountries": ["US"],
"status": "compliant"
},
"acl": {
"viewers": [],
"owners": []
},
"data": data}
request["legal"]["legaltags"] = legal_tag["legaltags"]
request["acl"]["viewers"] = acl["viewers"]
request["acl"]["owners"] = acl["owners"]
return request
def separate_type_data(request_data):
"""Separate the list of tuples into Data Type Counter and data list
:param request_data: tuple of data and types
:type request_data: tuple(list, str)
:return: counter with data types and data list
:rtype: tuple(counter, list)
"""
data = []
types = Counter()
for elem in request_data:
data.append(elem[0])
types[elem[1]] += 1
logger.info(f"The count of records to be ingested: {str(dict(types))}")
return types, data
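# Worked example: separate_type_data([(rec_a, "File"), (rec_b, "File"), (rec_c, "seismic_wpc")])
# returns (Counter({"File": 2, "seismic_wpc": 1}), [rec_a, rec_b, rec_c]).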
def create_manifest_request_data(loaded_conf: dict, product_type: str):
acl = loaded_conf.get("acl")
legal_tag = loaded_conf.get("legal_tag")
data_object = loaded_conf.get("data_object")
data_objects_list = [
(
populate_request_body(data_object["Manifest"], acl, legal_tag, product_type),
product_type)]
return data_objects_list
def create_workproduct_request_data(loaded_conf: dict, product_type: str, wp, wpc_list, file_list):
data_object_list = file_list + wpc_list + wp
types, data_objects_list = separate_type_data(data_object_list)
return data_objects_list
@tenacity.retry(
wait=tenacity.wait_fixed(TIMEOUT),
stop=tenacity.stop_after_attempt(RETRIES),
reraise=True
)
@authorize(AirflowTokenRefresher())
def send_request(headers, request_data):
"""Send request to records storage API."""
logger.error(f"Header {str(headers)}")
# retry loop around the record-creation request
retries = RETRIES
for retry in range(retries):
try:
# send batch request for creating records
response = requests.put(Variable.get('core__service__storage__url'), json.dumps(request_data),
headers=headers)
if response.status_code in DATA_LOAD_OK_RESPONSE_CODES:
logger.info(",".join(map(str, response.json()["recordIds"])))
return response
reason = response.text[:250]
logger.error("Request error.")
logger.error(f"Response status: {response.status_code}. "
f"Response content: {reason}.")
if retry + 1 < retries:
if response.status_code in BAD_TOKEN_RESPONSE_CODES:
logger.error("Invalid or expired token.")
return response
else:
time_to_sleep = TIMEOUT
logger.info(f"Retrying in {time_to_sleep} seconds...")
time.sleep(time_to_sleep)
except (requests.RequestException, HTTPError) as exc:
logger.error(f"Unexpected request error. Reason: {exc}")
sys.exit(2)
# End the script if we ran out of retries and the data could not be uploaded.
else:
logger.error(f"Request could not be completed.\n"
f"Reason: {reason}")
sys.exit(2)
def process_manifest(**kwargs):
"""Process manifest."""
loaded_conf, conf_payload = dataload(**kwargs)
file_type, product_type = validate_file(loaded_conf)
if file_type is FileType.MANIFEST:
manifest_record = create_manifest_request_data(loaded_conf, product_type)
elif file_type is FileType.WORKPRODUCT:
file_list, file_ids = process_file_items(loaded_conf, conf_payload)
kwargs["ti"].xcom_push(key="file_ids", value=file_ids)
wpc_list, wpc_ids = process_wpc_items(loaded_conf, product_type, file_ids, conf_payload)
wp_list = process_wp_item(loaded_conf, product_type, wpc_ids, conf_payload)
manifest_record = create_workproduct_request_data(loaded_conf, product_type, wp_list,
wpc_list,
file_list)
else:
sys.exit(2)
headers = create_headers(conf_payload)
record_ids = send_request(headers, request_data=manifest_record).json()["recordIds"]
kwargs["ti"].xcom_push(key="record_ids", value=record_ids)
class ProcessManifestOperatorR2(BaseOperator):
"""R2 Manifest Operator."""
ui_color = '#dad5ff'
ui_fgcolor = '#000000'
def execute(self, context):
process_manifest(**context)
# Copyright 2020 Google LLC
# Copyright 2020 EPAM Systems
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
R3 Process Manifest operator.
"""
import logging
from math import ceil
from typing import List, Tuple
from airflow.models import BaseOperator, Variable
from jsonschema import SchemaError
from osdu_api.libs.constants import DATA_TYPES_WITH_SURROGATE_KEYS, SURROGATE_KEYS_PATHS
from osdu_api.libs.context import Context
from osdu_api.libs.exceptions import (EmptyManifestError, GenericManifestSchemaError,
GetSchemaError, NotOSDUSchemaFormatError, UploadFileError)
from osdu_api.libs.handle_file import FileHandler
from osdu_api.libs.process_manifest_r3 import ManifestProcessor
from osdu_api.libs.processors.single_manifest_processor import SingleManifestProcessor
from osdu_api.libs.refresh_token import AirflowTokenRefresher
from osdu_api.libs.search_client import SearchClient
from osdu_api.libs.source_file_check import SourceFileChecker
from osdu_api.libs.types import ManifestType
from osdu_api.libs.validation.validate_file_source import FileSourceValidator
from osdu_api.libs.validation.validate_referential_integrity import ManifestIntegrity
from osdu_api.libs.validation.validate_schema import SchemaValidator
from osdu_airflow.backward_compatibility.airflow_utils import apply_defaults
from osdu_manifest.operators.mixins.ReceivingContextMixin import ReceivingContextMixin
from requests import HTTPError
logger = logging.getLogger()
class ProcessManifestOperatorR3(BaseOperator, ReceivingContextMixin):
"""Operator to process manifest R3."""
ui_color = '#dad5ff'
ui_fgcolor = '#000000'
@apply_defaults
def __init__(self, previous_task_id: str = None, batch_number=3, *args, **kwargs):
"""Init base operator and obtain base urls from Airflow Variables."""
super().__init__(*args, **kwargs)
self.previous_task_id = previous_task_id
self.batch_number = batch_number
self.schema_service_url = Variable.get('core__service__schema__url')
self.search_service_url = Variable.get('core__service__search__url')
self.storage_url = Variable.get('core__service__storage__url')
self.file_service_host = Variable.get('core__service__file__host')
self.batch_count = int(Variable.get("core__ingestion__batch_count", "3"))
self._show_skipped_ids = Variable.get('core__config__show_skipped_ids', default_var=False)
def _get_manifest_files_range(self, manifests: List[dict]) -> Tuple[int, int]:
"""
Get start and end indexes of a manifest files slice to be processed within this task.
:param manifests: A list of manifests
:return: start index, end index
"""
split_size = ceil(len(manifests) / self.batch_count)
slice_start_index = (self.batch_number - 1) * split_size
slice_end_index = self.batch_number * split_size
return slice_start_index, slice_end_index
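# Worked example: with batch_count = 3 and 10 manifests, split_size = ceil(10 / 3) = 4,
# so batch_number 1 processes manifests[0:4], batch 2 manifests[4:8] and batch 3
# manifests[8:12]; Python slicing clamps the last range to the two remaining items.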
def _process_manifest(
self,
single_manifest_processor: SingleManifestProcessor,
manifest: ManifestType
) -> Tuple[List[str], List[dict]]:
"""
:param single_manifest_processor: Object to process a single manifest file.
Processing includes validation against schemas, storing records, etc.
:param manifest: A single manifest file or a list of them.
:return: list of saved record ids and list of skipped entities
"""
skipped_entities = []
if isinstance(manifest, dict):
record_ids, skipped_entities = single_manifest_processor.process_manifest(
manifest, False)
elif isinstance(manifest, list):
record_ids = []
slice_start_index, slice_end_index = self._get_manifest_files_range(manifest)
logger.debug(f"Start and end indexes {slice_start_index}:{slice_end_index}")
for single_manifest in manifest[slice_start_index:slice_end_index]:
logger.debug(f"processing {single_manifest}")
try:
saved_records_ids, not_saved_records = single_manifest_processor.process_manifest(
single_manifest, True
)
record_ids.extend(saved_records_ids)
skipped_entities.extend(not_saved_records)
except (UploadFileError, HTTPError, GetSchemaError, SchemaError,
GenericManifestSchemaError) as e:
logger.error(f"Can't process {single_manifest}")
logger.error(e)
continue
else:
raise NotOSDUSchemaFormatError(
f"Manifest {manifest} must be either a non-empty 'list' or a 'dict'")
return record_ids, skipped_entities
def execute(self, context: dict):
"""Execute manifest validation then process it.
Get a single manifest file or a list of them.
If it is a list, calculate which range (slice) of manifest files must be processed and then
process this range one by one.
:param context: Airflow context
:type context: dict
"""
execution_context = context["dag_run"].conf["execution_context"]
payload_context = Context.populate(execution_context)
token_refresher = AirflowTokenRefresher()
file_handler = FileHandler(self.file_service_host, token_refresher, payload_context)
file_source_validator = FileSourceValidator()
search_client = SearchClient(self.search_service_url, token_refresher, payload_context)
source_file_checker = SourceFileChecker()
referential_integrity_validator = ManifestIntegrity(
self.search_service_url,
token_refresher,
file_source_validator,
payload_context
)
manifest_processor = ManifestProcessor(
storage_url=self.storage_url,
file_handler=file_handler,
token_refresher=token_refresher,
context=payload_context,
source_file_checker=source_file_checker,
)
validator = SchemaValidator(
self.schema_service_url,
token_refresher,
payload_context,
data_types_with_surrogate_ids=DATA_TYPES_WITH_SURROGATE_KEYS,
surrogate_key_fields_paths=SURROGATE_KEYS_PATHS
)
single_manifest_processor = SingleManifestProcessor(
storage_url=self.storage_url,
payload_context=payload_context,
referential_integrity_validator=referential_integrity_validator,
manifest_processor=manifest_processor,
schema_validator=validator,
token_refresher=token_refresher,
)
manifest_data = self._get_manifest_data(context, execution_context)
logger.debug(f"Manifest data: {manifest_data}")
if not manifest_data:
raise EmptyManifestError(
f"Data {context['dag_run'].conf} doesn't contain a 'manifest' field")
record_ids, skipped_ids = self._process_manifest(single_manifest_processor, manifest_data)
logger.info(f"Processed ids {record_ids}")
context["ti"].xcom_push(key="record_ids", value=record_ids)
if self._show_skipped_ids:
context["ti"].xcom_push(key="skipped_ids", value=skipped_ids)
# Copyright 2020 Google LLC
# Copyright 2020 EPAM Systems
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from airflow.models import BaseOperator, Variable
from osdu_api.libs.context import Context
from osdu_api.libs.refresh_token import AirflowTokenRefresher
from osdu_api.libs.search_record_ids import SearchId
logger = logging.getLogger()
class SearchRecordIdOperator(BaseOperator):
"""Operator to search files in SearchService by record ids.
Expects "record_ids" field in xcom.
"""
ui_color = '#10ECAA'
ui_fgcolor = '#000000'
FINISHED_STATUS = "finished"
RUNNING_STATUS = "running"
FAILED_STATUS = "failed"
def execute(self, context: dict):
"""Execute update workflow status.
If the status is assumed to be FINISHED, we check whether the processed files
are searchable or not.
If they are, the status is updated to FINISHED, otherwise to FAILED.
:param context: Airflow dagrun context
:type context: dict
"""
payload_context = Context.populate(context["dag_run"].conf)
record_ids = context["ti"].xcom_pull(key="record_ids")
ids_searcher = SearchId(Variable.get("core__service__search__url"), record_ids, AirflowTokenRefresher(),
payload_context)
ids_searcher.check_records_searchable()
# Copyright 2020 Google LLC
# Copyright 2020 EPAM Systems
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Update Status operator."""
import copy
import enum
import logging
from typing import Tuple
from airflow.models import BaseOperator, Variable
from osdu_airflow.backward_compatibility.airflow_utils import apply_defaults
from osdu_api.libs.context import Context
from osdu_api.libs.exceptions import PipelineFailedError
from osdu_api.libs.refresh_token import AirflowTokenRefresher
from osdu_api.libs.update_status import UpdateStatus
logger = logging.getLogger()
class UpdateStatusOperator(BaseOperator):
"""Operator to update status."""
ui_color = '#10ECAA'
ui_fgcolor = '#000000'
@apply_defaults
def __init__(self, *args, **kwargs):
"""Init base operator and obtain base urls from Airflow Variables."""
super().__init__(*args, **kwargs)
self._show_skipped_ids = Variable.get('core__config__show_skipped_ids', default_var=False)
class prev_ti_state(enum.Enum):
NONE = "running"
SUCCESS = "finished"
FAILED = "failed"
def get_previous_ti_statuses(self, context: dict) -> enum.Enum:
"""Get status of previous tasks' executions.
Return corresponding enum value.
:param context: Airflow context
:type context: dict
:return: Previous status
:rtype: enum.Enum
"""
dagrun = context['ti'].get_dagrun()
failed_ti = dagrun.get_task_instances(state='failed')
success_ti = dagrun.get_task_instances(state='success')
if not failed_ti and not success_ti:  # there are no previous tasks, so none can have failed
logger.info("There are no tasks before this one, so its status is RUNNING")
return self.prev_ti_state.NONE
if failed_ti:
logger.info("There are failed tasks before this one, so its status is FAILED")
return self.prev_ti_state.FAILED
logger.info("There are successful tasks before this one, so its status is SUCCESS")
return self.prev_ti_state.SUCCESS
def pre_execute(self, context: dict):
self.status = self.get_previous_ti_statuses(context)
def _create_skipped_report(self, context: dict) -> Tuple[dict, dict]:
"""
Return an aggregated report of skipped ids grouped by task
:param context:
:return: Aggregated report grouped by tasks
"""
skipped_ids_report = {}
saved_record_ids = {}
dagrun = context['ti'].get_dagrun()
task_instances = dagrun.get_task_instances()
for task in task_instances:
task_skipped_ids = context["ti"].xcom_pull(key="skipped_ids", task_ids=task.task_id)
if task_skipped_ids:
skipped_ids_report[task.task_id] = task_skipped_ids
for task in task_instances:
task_saved_ids = context["ti"].xcom_pull(key="record_ids", task_ids=task.task_id)
if task_saved_ids:
saved_record_ids[task.task_id] = task_saved_ids
return skipped_ids_report, saved_record_ids
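# Illustrative shape of the returned tuple (task ids and record ids are made up):
# ({"validate_manifest_schema_task": ["surrogate-key:wpc-1"]},
#  {"process_single_manifest_file_task": ["opendes:work-product--WorkProduct:1:123"]})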
def execute(self, context: dict):
"""Execute update workflow status.
If the status is assumed to be FINISHED, we check whether the records
are searchable or not.
If they are, the status is updated to FINISHED, otherwise to FAILED.
:param context: Airflow context
:type context: dict
:raises PipelineFailedError: If any previous task failed
"""
conf = copy.deepcopy(context["dag_run"].conf)
logger.debug(f"Got conf {conf}.")
execution_context = conf["execution_context"]
if "Payload" in execution_context:
payload_context = Context.populate(execution_context)
else:
payload_context = Context(data_partition_id=execution_context["data-partition-id"],
app_key=execution_context.get("AppKey", ""))
workflow_name = conf["workflow_name"]
run_id = conf["run_id"]
status = self.status.value
status_updater = UpdateStatus(
workflow_name=workflow_name,
workflow_url=Variable.get("core__service__workflow__host"),
workflow_id="",
run_id=run_id,
status=status,
token_refresher=AirflowTokenRefresher(),
context=payload_context
)
status_updater.update_workflow_status()
if self._show_skipped_ids:
skipped_ids, saved_record_ids = self._create_skipped_report(context)
context["ti"].xcom_push(key="skipped_ids", value=skipped_ids)
context["ti"].xcom_push(key="saved_record_ids", value=saved_record_ids)
if self.status is self.prev_ti_state.FAILED:
raise PipelineFailedError("Dag failed")
# Copyright 2021 Google LLC
# Copyright 2021 EPAM Systems
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Validate Manifest against R3 schemas operator.
"""
import logging
from airflow.models import BaseOperator, Variable
from osdu_airflow.backward_compatibility.airflow_utils import apply_defaults
from osdu_api.libs.constants import DATA_TYPES_WITH_SURROGATE_KEYS, SURROGATE_KEYS_PATHS
from osdu_api.libs.context import Context
from osdu_api.libs.exceptions import EmptyManifestError, GenericManifestSchemaError
from osdu_api.libs.refresh_token import AirflowTokenRefresher
from osdu_api.libs.validation.validate_schema import SchemaValidator
from osdu_manifest.operators.mixins.ReceivingContextMixin import ReceivingContextMixin
logger = logging.getLogger()
class ValidateManifestSchemaOperator(BaseOperator, ReceivingContextMixin):
"""Operator to validate a manifest against R3 definition schemas."""
ui_color = '#dad5ff'
ui_fgcolor = '#000000'
@apply_defaults
def __init__(self, previous_task_id: str = None, *args, **kwargs):
"""Init base operator and obtain base urls from Airflow Variables."""
super().__init__(*args, **kwargs)
self.previous_task_id = previous_task_id
self.schema_service_url = Variable.get('core__service__schema__url')
self._show_skipped_ids = Variable.get(
'core__config__show_skipped_ids', default_var=False
)
def execute(self, context: dict):
"""Execute manifest validation then process it.
Get a single manifest file or a list of them.
If it is a list, calculate which range (slice) of manifest files must be processed and then
process this range one by one.
:param context: Airflow context
:type context: dict
"""
execution_context = context["dag_run"].conf["execution_context"]
payload_context = Context.populate(execution_context)
token_refresher = AirflowTokenRefresher()
schema_validator = SchemaValidator(
self.schema_service_url,
token_refresher,
payload_context,
surrogate_key_fields_paths=SURROGATE_KEYS_PATHS,
data_types_with_surrogate_ids=DATA_TYPES_WITH_SURROGATE_KEYS
)
manifest_data = self._get_manifest_data(context, execution_context)
logger.debug(f"Manifest data: {manifest_data}")
if not manifest_data:
raise EmptyManifestError(
f"Data {context['dag_run'].conf} doesn't contain a 'manifest' field")
_ = schema_validator.validate_common_schema(manifest_data)
try:
valid_manifest_file, skipped_entities = schema_validator.ensure_manifest_validity(
manifest_data
)
except GenericManifestSchemaError as err:
context["ti"].xcom_push(key="skipped_ids", value=str(err))
raise err
if self._show_skipped_ids:
context["ti"].xcom_push(key="skipped_ids", value=skipped_entities)
return {"manifest": valid_manifest_file}
```
pip install pytest
export AIRFLOW_SRC_DIR=/path/to/airflow-folder
pytest
```
# Copyright 2020 Google LLC
# Copyright 2020 EPAM Systems
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright 2020 Google LLC
# Copyright 2020 EPAM Systems
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .process_manifest_r2_op import *