diff --git a/src/dags/libs/create_records.py b/src/dags/libs/create_records.py index 90f3a8fb570d1e0122df8ff9db348c9bb4e92001..b8446cda589a74e92d745fa857b1f36908ea0a19 100644 --- a/src/dags/libs/create_records.py +++ b/src/dags/libs/create_records.py @@ -17,6 +17,7 @@ import configparser import logging from airflow.models import Variable +from libs.refresh_token import refresh_token from osdu_api.model.acl import Acl from osdu_api.model.legal import Legal from osdu_api.model.legal_compliance import LegalCompliance @@ -26,7 +27,6 @@ from osdu_api.storage.record_client import RecordClient logger = logging.getLogger() - ACL_DICT = eval(Variable.get("acl")) LEGAL_DICT = eval(Variable.get("legal")) @@ -38,6 +38,12 @@ DEFAULT_SOURCE = config.get("DEFAULTS", "authority") DEFAULT_VERSION = config.get("DEFAULTS", "kind_version") +@refresh_token +def create_update_record_request(headers, record_client, record): + resp = record_client.create_update_records([record], headers.items()) + return resp + + def create_records(**kwargs): # the only way to pass in values through the experimental api is through # the conf parameter @@ -69,14 +75,13 @@ def create_records(**kwargs): headers = { "content-type": "application/json", "slb-data-partition-id": data_conf.get("partition-id", DEFAULT_SOURCE), - "Authorization": f"{auth}", "AppKey": data_conf.get("app-key", "") } record_client = RecordClient() record_client.data_partition_id = data_conf.get( "partition-id", DEFAULT_SOURCE) - resp = record_client.create_update_records([record], headers.items()) + resp = create_update_record_request(headers, record_client, record) logger.info(f"Response: {resp.text}") kwargs["ti"].xcom_push(key="record_ids", value=resp.json()["recordIds"]) return {"response_status": resp.status_code} diff --git a/src/dags/osdu-ingest.py b/src/dags/osdu-ingest.py index 9c602285728f71d88d49bc971c80a0f110b01c0f..f43fb22358f21742d45da7d21a5d0afb1addd017 100644 --- a/src/dags/osdu-ingest.py +++ b/src/dags/osdu-ingest.py @@ -58,7 +58,8 @@ process_manifest_op = ProcessManifestOperator( search_record_ids_op = SearchRecordIdOperator( task_id="search_record_ids_task", provide_context=True, - dag=dag + dag=dag, + retries=1 ) update_status_running_op >> process_manifest_op >> \ diff --git a/src/plugins/libs/exceptions.py b/src/plugins/libs/exceptions.py new file mode 100644 index 0000000000000000000000000000000000000000..446ec6ea1f6662f1e182443d1264ed81be6d55f0 --- /dev/null +++ b/src/plugins/libs/exceptions.py @@ -0,0 +1,35 @@ +# Copyright 2020 Google LLC +# Copyright 2020 EPAM Systems +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +class RecordsNotSearchableError(Exception): + """ + Raise when expected totalCount of records differs from actual one. + """ + pass + + +class RefreshSATokenError(Exception): + """ + Raise when token is empty after attempt to get credentials from service account file. + """ + pass + + +class PipelineFailedError(Exception): + """ + Raise when pipeline failed. 
+ """ + pass diff --git a/src/plugins/libs/refresh_token.py b/src/plugins/libs/refresh_token.py new file mode 100644 index 0000000000000000000000000000000000000000..384e475730c090a4e52ec990a11600b324343685 --- /dev/null +++ b/src/plugins/libs/refresh_token.py @@ -0,0 +1,134 @@ +# Copyright 2020 Google LLC +# Copyright 2020 EPAM Systems +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import sys +from http import HTTPStatus + +import requests +from airflow.models import Variable +from google.auth.transport.requests import Request +from google.oauth2 import service_account +from libs.exceptions import RefreshSATokenError +from tenacity import retry, stop_after_attempt + +ACCESS_TOKEN = None + +# Set up base logger +handler = logging.StreamHandler(sys.stdout) +handler.setFormatter( + logging.Formatter("%(asctime)s [%(name)-14.14s] [%(levelname)-7.7s] %(message)s")) +logger = logging.getLogger("Dataload") +logger.setLevel(logging.INFO) +logger.addHandler(handler) + +RETRIES = 3 + +SA_FILE_PATH = Variable.get("sa-file-osdu") +ACCESS_SCOPES = ['openid', 'email', 'profile'] + + +@retry(stop=stop_after_attempt(RETRIES)) +def get_access_token(sa_file: str, scopes: list) -> str: + """ + Refresh access token. + """ + try: + credentials = service_account.Credentials.from_service_account_file( + sa_file, scopes=scopes) + except ValueError as e: + logger.error("SA file has bad format.") + raise e + logger.info("Refresh token.") + credentials.refresh(Request()) + token = credentials.token + if credentials.token is None: + logger.error("Can't refresh token using SA-file. Token is empty.") + raise RefreshSATokenError + logger.info("Token refreshed.") + return token + + +@retry(stop=stop_after_attempt(RETRIES)) +def set_access_token(sa_file: str, scopes: list) -> str: + """ + Create token + """ + global ACCESS_TOKEN + token = get_access_token(sa_file, scopes) + auth = f"Bearer {token}" + ACCESS_TOKEN = token + Variable.set("access_token", token) + return auth + + +def _check_token(): + global ACCESS_TOKEN + try: + if not ACCESS_TOKEN: + ACCESS_TOKEN = Variable.get('access_token') + except KeyError: + set_access_token(SA_FILE_PATH, ACCESS_SCOPES) + + +def _wrapper(*args, **kwargs): + """ + Generic decorator wrapper for checking token and refreshing it. + """ + _check_token() + obj = kwargs.pop("obj") if kwargs.get("obj") else None + headers = kwargs.pop("headers") + request_function = kwargs.pop("request_function") + if not isinstance(headers, dict): + logger.error("Got headers %s" % headers) + raise TypeError + headers["Authorization"] = f"Bearer {ACCESS_TOKEN}" + if obj: # if wrapped function is an object's method + send_request_with_auth = lambda: request_function(obj, headers, *args, **kwargs) + else: + send_request_with_auth = lambda: request_function(headers, *args, **kwargs) + response = send_request_with_auth() + if not isinstance(response, requests.Response): + logger.error("Function %s must return values of type requests.Response. 
" + "Got %s instead" % (kwargs["rqst_func"], type(response))) + raise TypeError + if not response.ok: + if response.status_code in (HTTPStatus.UNAUTHORIZED, HTTPStatus.FORBIDDEN): + set_access_token(SA_FILE_PATH, ACCESS_SCOPES) + response = send_request_with_auth() + else: + response.raise_for_status() + return response + + +def refresh_token(request_function): + """ + Wrap a request function and check response. If response's error status code + is about Authorization, refresh token and invoke this function once again. + Expects function: + If response is not ok and not about Authorization, then raises HTTPError + request_func(header: dict, *args, **kwargs) -> requests.Response + Or method: + request_method(self, header: dict, *args, **kwargs) -> requests.Response + """ + is_method = len(request_function.__qualname__.split(".")) > 1 + if is_method: + def wrapper(obj, headers, *args, **kwargs): + return _wrapper(request_function=request_function, obj=obj, headers=headers, *args, + **kwargs) + else: + def wrapper(headers, *args, **kwargs): + return _wrapper(request_function=request_function, headers=headers, *args, **kwargs) + return wrapper diff --git a/src/plugins/operators/process_manifest_op.py b/src/plugins/operators/process_manifest_op.py index 2900281ff3115ea558ac427d70c06d0a840fb837..6e63848521401d0c789e534c2772861cdaf5cf0e 100644 --- a/src/plugins/operators/process_manifest_op.py +++ b/src/plugins/operators/process_manifest_op.py @@ -28,6 +28,7 @@ from urllib.error import HTTPError import requests from airflow.models import BaseOperator, Variable +from libs.refresh_token import refresh_token ACL_DICT = eval(Variable.get("acl")) LEGAL_DICT = eval(Variable.get("legal")) @@ -44,7 +45,8 @@ TIMEOUT = 1 # Set up base logger handler = logging.StreamHandler(sys.stdout) -handler.setFormatter(logging.Formatter("%(asctime)s [%(name)-14.14s] [%(levelname)-7.7s] %(message)s")) +handler.setFormatter( + logging.Formatter("%(asctime)s [%(name)-14.14s] [%(levelname)-7.7s] %(message)s")) logger = logging.getLogger("Dataload") logger.setLevel(logging.INFO) logger.addHandler(handler) @@ -69,6 +71,7 @@ class FileType(enum.Enum): MANIFEST = enum.auto() WORKPRODUCT = enum.auto() + def dataload(**kwargs): data_conf = kwargs['dag_run'].conf loaded_conf = { @@ -80,6 +83,17 @@ def dataload(**kwargs): return loaded_conf, conf_payload +def create_headers(conf_payload): + partition_id = conf_payload["data-partition-id"] + app_key = conf_payload["AppKey"] + headers = { + 'Content-type': 'application/json', + 'data-partition-id': partition_id, + 'AppKey': app_key + } + return headers + + def generate_id(type_id): """ Generate resource ID @@ -100,6 +114,7 @@ def determine_data_type(raw_resource_type_id): return raw_resource_type_id.split("/")[-1].replace(":", "") \ if raw_resource_type_id is not None else None + # TODO: add comments to functions that implement actions in this function def process_file_items(loaded_conf, conf_payload) -> Tuple[list, list]: file_ids = [] @@ -118,6 +133,7 @@ def process_file_items(loaded_conf, conf_payload) -> Tuple[list, list]: ) return file_list, file_ids + def process_wpc_items(loaded_conf, product_type, file_ids, conf_payload): wpc_ids = [] wpc_list = [] @@ -153,6 +169,7 @@ def process_wp_item(loaded_conf, product_type, wpc_ids, conf_payload) -> list: ] return work_product + def validate_file_type(file_type, data_object): if not file_type: logger.error(f"Error with file {data_object}. 
Type could not be specified.") @@ -172,8 +189,8 @@ def validate_file(loaded_conf) -> Tuple[FileType, str]: product_type = determine_data_type(data_object["WorkProduct"].get("ResourceTypeID")) validate_file_type(product_type, data_object) if product_type.lower() == "workproduct" and \ - data_object.get("WorkProductComponents") and \ - len(data_object["WorkProductComponents"]) >= 1: + data_object.get("WorkProductComponents") and \ + len(data_object["WorkProductComponents"]) >= 1: product_type = determine_data_type( data_object["WorkProductComponents"][0].get("ResourceTypeID")) validate_file_type(product_type, data_object) @@ -183,6 +200,7 @@ def validate_file(loaded_conf) -> Tuple[FileType, str]: f"Error with file {data_object}. It doesn't have either Manifest or WorkProduct or ResourceTypeID.") sys.exit(2) + def create_kind(data_kind, conf_payload): partition_id = conf_payload.get("data-partition-id", DEFAULT_TENANT) source = conf_payload.get("authority", DEFAULT_SOURCE) @@ -241,7 +259,9 @@ def create_manifest_request_data(loaded_conf: dict, product_type: str): legal_tag = loaded_conf.get("legal_tag") data_object = loaded_conf.get("data_object") data_objects_list = [ - (populate_request_body(data_object["Manifest"], acl, legal_tag, product_type), product_type)] + ( + populate_request_body(data_object["Manifest"], acl, legal_tag, product_type), + product_type)] return data_objects_list @@ -251,19 +271,11 @@ def create_workproduct_request_data(loaded_conf: dict, product_type: str, wp, wp return data_objects_list -def send_request(request_data, conf_payload): +@refresh_token +def send_request(headers, request_data): """ Send request to records storage API """ - auth = conf_payload["authorization"] - partition_id = conf_payload["data-partition-id"] - app_key = conf_payload["AppKey"] - headers = { - 'Content-type': 'application/json', - 'data-partition-id': partition_id, - 'Authorization': auth, - 'AppKey': app_key - } logger.error(f"Header {str(headers)}") # loop for implementing retries send process @@ -277,7 +289,7 @@ def send_request(request_data, conf_payload): if response.status_code in DATA_LOAD_OK_RESPONSE_CODES: file_logger.info(",".join(map(str, response.json()["recordIds"]))) - break + return response reason = response.text[:250] logger.error(f"Request error.") @@ -287,6 +299,7 @@ def send_request(request_data, conf_payload): if retry + 1 < retries: if response.status_code in BAD_TOKEN_RESPONSE_CODES: logger.error("Invalid or expired token.") + return response else: time_to_sleep = TIMEOUT @@ -303,23 +316,25 @@ def send_request(request_data, conf_payload): logger.error(f"Request could not be completed.\n" f"Reason: {reason}") sys.exit(2) - return response.json()["recordIds"] def process_manifest(**kwargs): loaded_conf, conf_payload = dataload(**kwargs) file_type, product_type = validate_file(loaded_conf) if file_type is FileType.MANIFEST: - request = create_manifest_request_data(loaded_conf, product_type) + manifest_record = create_manifest_request_data(loaded_conf, product_type) elif file_type is FileType.WORKPRODUCT: file_list, file_ids = process_file_items(loaded_conf, conf_payload) kwargs["ti"].xcom_push(key="file_ids", value=file_ids) wpc_list, wpc_ids = process_wpc_items(loaded_conf, product_type, file_ids, conf_payload) wp_list = process_wp_item(loaded_conf, product_type, wpc_ids, conf_payload) - request = create_workproduct_request_data(loaded_conf, product_type, wp_list, wpc_list, file_list) + manifest_record = create_workproduct_request_data(loaded_conf, product_type, 
wp_list, + wpc_list, + file_list) else: sys.exit(2) - record_ids = send_request(request, conf_payload) + headers = create_headers(conf_payload) + record_ids = send_request(headers, request_data=manifest_record).json()["recordIds"] kwargs["ti"].xcom_push(key="record_ids", value=record_ids) diff --git a/src/plugins/operators/search_record_id_op.py b/src/plugins/operators/search_record_id_op.py index fefe1809cbe1ee6a41480ad508ee90c21cd88826..90d58d21cdde505950596c3b306d9a17151c3358 100644 --- a/src/plugins/operators/search_record_id_op.py +++ b/src/plugins/operators/search_record_id_op.py @@ -14,28 +14,27 @@ # limitations under the License. -import enum import json import logging import sys -from functools import partial from typing import Tuple import tenacity from airflow.models import BaseOperator, Variable from airflow.utils.decorators import apply_defaults - from hooks import search_http_hook, workflow_hook +from libs.exceptions import RecordsNotSearchableError +from libs.refresh_token import refresh_token # Set up base logger handler = logging.StreamHandler(sys.stdout) -handler.setFormatter(logging.Formatter("%(asctime)s [%(name)-14.14s] [%(levelname)-7.7s] %(message)s")) +handler.setFormatter( + logging.Formatter("%(asctime)s [%(name)-14.14s] [%(levelname)-7.7s] %(message)s")) logger = logging.getLogger("Dataload") logger.setLevel(logging.INFO) logger.addHandler(handler) - class SearchRecordIdOperator(BaseOperator): """ Operator to search files in SearchService by record ids. @@ -49,32 +48,25 @@ class SearchRecordIdOperator(BaseOperator): FAILED_STATUS = "failed" @apply_defaults - def __init__( self, *args, **kwargs) -> None: + def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self.workflow_hook = workflow_hook self.search_hook = search_http_hook - - @staticmethod - def _file_searched(resp, expected_total_count) -> bool: - """Check if search service returns expected totalCount. 
-        The method is used as a callback
-        """
-        data = resp.json()
-        return data.get("totalCount") == expected_total_count
+        # these will be set at the beginning of the execute method
+        self.request_body = None
+        self.expected_total_count = None
 
     def get_headers(self, **kwargs) -> dict:
         data_conf = kwargs['dag_run'].conf
         # for /submitWithManifest authorization and partition-id are inside Payload field
         if "Payload" in data_conf:
-            auth = data_conf["Payload"]["authorization"]
             partition_id = data_conf["Payload"]["data-partition-id"]
         else:
-            auth = data_conf["authorization"]
             partition_id = data_conf["data-partition-id"]
         headers = {
             'Content-type': 'application/json',
             'data-partition-id': partition_id,
-            'Authorization': auth,
+            'Authorization': "",
         }
         return headers
 
@@ -86,35 +78,51 @@ class SearchRecordIdOperator(BaseOperator):
             query = f"id:({record_ids})"
             return query, expected_total_count
 
-    def search_files(self, **kwargs):
-        record_ids = kwargs["ti"].xcom_pull(key="record_ids",)
+    def _create_request_body(self, **kwargs):
+        record_ids = kwargs["ti"].xcom_pull(key="record_ids")
         if record_ids:
             query, expected_total_count = self._create_search_query(record_ids)
         else:
             logger.error("There are no record ids")
             sys.exit(2)
-        headers = self.get_headers(**kwargs)
         request_body = {
             "kind": "*:*:*:*",
             "query": query
         }
-        retry_opts = {
-            "wait": tenacity.wait_exponential(multiplier=5),
-            "stop": tenacity.stop_after_attempt(5),
-            "retry": tenacity.retry_if_not_result(
-                partial(self._file_searched, expected_total_count=expected_total_count)
+        return request_body, expected_total_count
+
+    def _is_record_searchable(self, resp) -> bool:
+        """
+        Check if the search service returns the expected totalCount of records.
+        """
+        data = resp.json()
+        return data.get("totalCount") == self.expected_total_count
+
+    @tenacity.retry(wait=tenacity.wait_exponential(multiplier=5), stop=tenacity.stop_after_attempt(5))
+    @refresh_token
+    def search_files(self, headers, **kwargs):
+        if self.request_body:
+            response = self.search_hook.run(
+                endpoint=Variable.get("search_query_ep"),
+                headers=headers,
+                data=json.dumps(self.request_body),
+                extra_options={"check_response": False}
             )
-        }
-        self.search_hook.run_with_advanced_retry(
-            endpoint=Variable.get("search_query_ep"),
-            headers=headers,
-            data=json.dumps(request_body),
-            _retry_args=retry_opts
-        )
+            if not self._is_record_searchable(response):
+                logger.error("Expected amount (%s) of records not found." %
+                             self.expected_total_count
+                             )
+                raise RecordsNotSearchableError
+            return response
+        else:
+            logger.error("There is an error in the headers or in the request body")
+            sys.exit(2)
 
     def execute(self, context):
         """Execute update workflow status.
         If status assumed to be FINISHED then we check whether proceed files are searchable or not.
If they are then update status FINISHED else FAILED """ - self.search_files(**context) + self.request_body, self.expected_total_count = self._create_request_body(**context) + headers = self.get_headers(**context) + self.search_files(headers, **context) diff --git a/src/plugins/operators/update_status_op.py b/src/plugins/operators/update_status_op.py index 5ae50498ad42a1936aab5a045b61186a691086f9..1acdc6d01991bfaf85308ac33cd986ac258f895c 100644 --- a/src/plugins/operators/update_status_op.py +++ b/src/plugins/operators/update_status_op.py @@ -23,20 +23,20 @@ from functools import partial import tenacity from airflow.models import BaseOperator, Variable from airflow.utils.decorators import apply_defaults - from hooks import search_http_hook, workflow_hook +from libs.exceptions import PipelineFailedError +from libs.refresh_token import refresh_token # Set up base logger handler = logging.StreamHandler(sys.stdout) -handler.setFormatter(logging.Formatter("%(asctime)s [%(name)-14.14s] [%(levelname)-7.7s] %(message)s")) +handler.setFormatter( + logging.Formatter("%(asctime)s [%(name)-14.14s] [%(levelname)-7.7s] %(message)s")) logger = logging.getLogger("Dataload") logger.setLevel(logging.INFO) logger.addHandler(handler) - class UpdateStatusOperator(BaseOperator): - ui_color = '#10ECAA' ui_fgcolor = '#000000' @@ -50,7 +50,7 @@ class UpdateStatusOperator(BaseOperator): FAILED = enum.auto() @apply_defaults - def __init__( self, *args, **kwargs) -> None: + def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self.workflow_hook = workflow_hook self.search_hook = search_http_hook @@ -67,15 +67,12 @@ class UpdateStatusOperator(BaseOperator): data_conf = kwargs['dag_run'].conf # for /submitWithManifest authorization and partition-id are inside Payload field if "Payload" in data_conf: - auth = data_conf["Payload"]["authorization"] partition_id = data_conf["Payload"]["data-partition-id"] else: - auth = data_conf["authorization"] partition_id = data_conf["data-partition-id"] headers = { 'Content-type': 'application/json', 'data-partition-id': partition_id, - 'Authorization': auth, } return headers @@ -114,8 +111,9 @@ class UpdateStatusOperator(BaseOperator): def previous_ti_statuses(self, context): dagrun = context['ti'].get_dagrun() - failed_ti, success_ti = dagrun.get_task_instances(state='failed'), dagrun.get_task_instances(state='success') - if not failed_ti and not success_ti: # There is no prev task so it can't have been failed + failed_ti, success_ti = dagrun.get_task_instances( + state='failed'), dagrun.get_task_instances(state='success') + if not failed_ti and not success_ti: # There is no prev task so it can't have been failed logger.info("There are no tasks before this one. So it has status RUNNING") return self.prev_ti_state.NONE if failed_ti: @@ -138,22 +136,25 @@ class UpdateStatusOperator(BaseOperator): If status assumed to be FINISHED then we check whether proceed files are searchable or not. 
If they are then update status FINISHED else FAILED """ - self.update_status_rqst(self.status, **context) + headers = self.get_headers(**context) + self.update_status_request(headers, self.status, **context) if self.status == self.FAILED_STATUS: - raise Exception("Dag failed") + raise PipelineFailedError("Dag failed") - def update_status_rqst(self, status, **kwargs): + @refresh_token + def update_status_request(self, headers, status, **kwargs): data_conf = kwargs['dag_run'].conf logger.info(f"Got dataconf {data_conf}") workflow_id = data_conf["WorkflowID"] - headers = self.get_headers(**kwargs) request_body = { "WorkflowID": workflow_id, "Status": status } logger.info(f" Sending request '{status}'") - self.workflow_hook.run( + response = self.workflow_hook.run( endpoint=Variable.get("update_status_ep"), data=json.dumps(request_body), - headers=headers + headers=headers, + extra_options={"check_response": False} ) + return response diff --git a/tests/end-to-end-tests/mock-external-apis/app.py b/tests/end-to-end-tests/mock-external-apis/app.py index e88de38dd5f51f4a9a32a235d14d0749055ccc1d..23f54021fc968191b985ad4dbd95a9e351d2594e 100644 --- a/tests/end-to-end-tests/mock-external-apis/app.py +++ b/tests/end-to-end-tests/mock-external-apis/app.py @@ -14,8 +14,7 @@ # limitations under the License. -from flask import Flask, url_for, request -from flask import json +from flask import Flask, json, request, url_for OSDU_INGEST_SUCCES_FIFO = "/tmp/osdu_ingest_success" OSDU_INGEST_FAILED_FIFO = "/tmp/osdu_ingest_failed" diff --git a/tests/plugin-unit-tests/data/__init__.py b/tests/plugin-unit-tests/data/__init__.py index 9df114afccf05e92222bab8f16892bbab9c21fc8..e9d8ac876ee1c3ce060bf9b178727657bf6f2758 100644 --- a/tests/plugin-unit-tests/data/__init__.py +++ b/tests/plugin-unit-tests/data/__init__.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .process_manifest_op import * \ No newline at end of file +from .process_manifest_op import * diff --git a/tests/plugin-unit-tests/test_process_manifest_op.py b/tests/plugin-unit-tests/test_process_manifest_op.py index 56717e2ff5c3a96d1e18fd9605d6f99600d0c4ef..3c516f4a801ed3c3851cfff00676d854b8ba1a89 100644 --- a/tests/plugin-unit-tests/test_process_manifest_op.py +++ b/tests/plugin-unit-tests/test_process_manifest_op.py @@ -14,45 +14,47 @@ # limitations under the License. 
-import re import os -import pytest +import re import sys +import pytest + sys.path.append(f"{os.getenv('AIRFLOW_SRC_DIR')}/plugins") -from operators import process_manifest_op as p_m_op from data import process_manifest_op as test_data +from operators import process_manifest_op + @pytest.mark.parametrize( - "test_input, expected", - [ - ("srn:type:work-product/WellLog:", "WellLog"), - ("srn:type:file/las2:", "las2"), - ] + "test_input, expected", + [ + ("srn:type:work-product/WellLog:", "WellLog"), + ("srn:type:file/las2:", "las2"), + ] ) def test_determine_data_type(test_input, expected): - data_type = p_m_op.determine_data_type(test_input) - assert data_type == expected + data_type = process_manifest_op.determine_data_type(test_input) + assert data_type == expected @pytest.mark.parametrize( - "data_type, loaded_conf, conf_payload, expected_file_result", - [ - ("las2", - test_data.LOADED_CONF, - test_data.CONF_PAYLOAD, - test_data.PROCESS_FILE_ITEMS_RESULT) - ] + "data_type, loaded_conf, conf_payload, expected_file_result", + [ + ("las2", + test_data.LOADED_CONF, + test_data.CONF_PAYLOAD, + test_data.PROCESS_FILE_ITEMS_RESULT) + ] ) def test_process_file_items(data_type, loaded_conf, conf_payload, expected_file_result): - file_id_regex = re.compile(r"srn\:file/" + data_type + r"\:\d+\:") - expected_file_list = expected_file_result[0] - file_list, file_ids = p_m_op.process_file_items(loaded_conf, conf_payload) - for i in file_ids: - assert file_id_regex.match(i) - - for i in file_list: - assert file_id_regex.match(i[0]["data"]["ResourceID"]) - i[0]["data"]["ResourceID"] = "" - assert file_list == expected_file_list + file_id_regex = re.compile(r"srn\:file/" + data_type + r"\:\d+\:") + expected_file_list = expected_file_result[0] + file_list, file_ids = process_manifest_op.process_file_items(loaded_conf, conf_payload) + for i in file_ids: + assert file_id_regex.match(i) + + for i in file_list: + assert file_id_regex.match(i[0]["data"]["ResourceID"]) + i[0]["data"]["ResourceID"] = "" + assert file_list == expected_file_list diff --git a/tests/set_airflow_env.sh b/tests/set_airflow_env.sh index ad293724fc6ff1718f84138d70eb8af0c938f7aa..21be656adb59f203d1351dc5b68c60b0d3047755 100755 --- a/tests/set_airflow_env.sh +++ b/tests/set_airflow_env.sh @@ -51,6 +51,9 @@ airflow variables -s update_status_ep wf/us airflow variables -s search_url $LOCALHOST airflow variables -s dataload_config_path $DATALOAD_CONFIG_PATH airflow variables -s search_query_ep sr/qr +airflow variables -s access_token test +airflow variables -s "sa-file-osdu" "test" + airflow connections -a --conn_id search --conn_uri $SEARCH_CONN_ID airflow connections -a --conn_id workflow --conn_uri $WORKFLOW_CONN_ID airflow connections -a --conn_id google_cloud_storage --conn_uri $WORKFLOW_CONN_ID diff --git a/tests/test_dags.py b/tests/test_dags.py index c0200186d17eaaee0b06ef49ac46ca96a79e04a5..ffd6a056f05130d54d1b9327ebc19b341e622d8c 100644 --- a/tests/test_dags.py +++ b/tests/test_dags.py @@ -17,28 +17,28 @@ import enum import subprocess import time + class DagStatus(enum.Enum): - RUNNING = enum.auto() - FAILED = enum.auto() - FINISHED = enum.auto() + RUNNING = "running" + FAILED = "failed" + FINISHED = "finished" + OSDU_INGEST_SUCCESS_SH = "/mock-server/./test-osdu-ingest-success.sh" OSDU_INGEST_FAIL_SH = "/mock-server/./test-osdu-ingest-fail.sh" DEFAULT_INGEST_SUCCESS_SH = "/mock-server/./test-default-ingest-success.sh" DEFAULT_INGEST_FAIL_SH = "/mock-server/./test-default-ingest-fail.sh" -with open("/tmp/osdu_ingest_result", "w") 
as f: - f.close() - subprocess.run(f"/bin/bash -c 'airflow scheduler > /dev/null 2>&1 &'", shell=True) -def check_dag_status(dag_name): + +def check_dag_status(dag_name: str) -> DagStatus: time.sleep(5) output = subprocess.getoutput(f'airflow list_dag_runs {dag_name}') if "failed" in output: print(dag_name) print(output) - return DagStatus.FAILED + return DagStatus.FAILED if "running" in output: return DagStatus.RUNNING print(dag_name) @@ -46,32 +46,17 @@ def check_dag_status(dag_name): return DagStatus.FINISHED -def test_dag_success(dag_name, script): - print(f"Test {dag_name} success") - subprocess.run(f"{script}", shell=True) - while True: - dag_status = check_dag_status(dag_name) - if dag_status is DagStatus.RUNNING: - continue - elif dag_status is DagStatus.FINISHED: - return - else: - raise Exception(f"Error {dag_name} supposed to be finished") - -def test_dag_fail(dag_name, script): +def test_dag_execution_result(dag_name: str, script: str, expected_status: DagStatus): subprocess.run(f"{script}", shell=True) - print(f"Expecting {dag_name} fail") + print(f"Expecting {dag_name} to be {expected_status.value}") while True: dag_status = check_dag_status(dag_name) - if dag_status is DagStatus.RUNNING: - continue - elif dag_status is DagStatus.FAILED: - return - else: - raise Exception(f"Error {dag_name} supposed to be failed") + if dag_status is not DagStatus.RUNNING: + break + assert dag_status is expected_status, f"Error {dag_name} supposed to be {expected_status.value}" -test_dag_success("Osdu_ingest", OSDU_INGEST_SUCCESS_SH) -test_dag_fail("Osdu_ingest", OSDU_INGEST_FAIL_SH) -test_dag_success("Default_ingest", DEFAULT_INGEST_SUCCESS_SH) -test_dag_fail("Default_ingest", DEFAULT_INGEST_FAIL_SH) +test_dag_execution_result("Osdu_ingest", OSDU_INGEST_SUCCESS_SH, DagStatus.FINISHED) +test_dag_execution_result("Osdu_ingest", OSDU_INGEST_FAIL_SH, DagStatus.FAILED) +test_dag_execution_result("Default_ingest", DEFAULT_INGEST_SUCCESS_SH, DagStatus.FINISHED) +test_dag_execution_result("Default_ingest", DEFAULT_INGEST_FAIL_SH, DagStatus.FAILED) diff --git a/tests/unit_tests.sh b/tests/unit_tests.sh index dc6e07c88fc480bc627a4fea63206d900f294ea2..ae3c686de750d7e09b2df1eda8042ba6cda29cc9 100644 --- a/tests/unit_tests.sh +++ b/tests/unit_tests.sh @@ -1,3 +1,4 @@ +pip uninstall enum34 -y pip install pytest pip install --upgrade google-api-python-client chmod +x tests/set_airflow_env.sh
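
Note (outside the patch): a minimal sketch, under the assumptions stated here, of how the new @refresh_token decorator from src/plugins/libs/refresh_token.py is intended to be used. Per the decorator's contract shown above, the wrapped callable must take the headers dict as its first argument (after self when it is a method) and return a requests.Response; the decorator fills in the "Authorization: Bearer ..." header and, on a 401/403 response, refreshes the service-account token and retries the call once. The function name, endpoint URL, and payload below are placeholders, not part of this change.

import json

import requests

from libs.refresh_token import refresh_token


@refresh_token
def send_example_request(headers: dict, url: str, payload: dict) -> requests.Response:
    # `headers` already carries the refreshed "Authorization: Bearer ..." value
    # injected by the decorator before this function runs.
    return requests.post(url, data=json.dumps(payload), headers=headers)


# Example call: positional arguments after `headers` are passed through unchanged.
# response = send_example_request(
#     {"Content-type": "application/json", "data-partition-id": "some-partition"},  # placeholder values
#     "https://example.com/api/storage/v2/records",                                 # placeholder URL
#     {"kind": "*:*:*:*"},
# )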