From 710b8152fdbc8331fbe2a0310ef60937efeea3a2 Mon Sep 17 00:00:00 2001 From: yan <yan_sushchynski@epam.com> Date: Mon, 26 Oct 2020 16:47:58 +0300 Subject: [PATCH] GONRG-790: Change naming. Add retry to get_token. --- src/dags/libs/create_records.py | 6 +-- src/plugins/libs/exceptions.py | 35 +++++++++++++ src/plugins/libs/refresh_token.py | 40 ++++++++++----- src/plugins/operators/search_record_id_op.py | 17 +++---- src/plugins/operators/update_status_op.py | 15 ++---- .../test_process_manifest_op.py | 50 +++++++++---------- tests/test_dags.py | 48 ++++++------------ 7 files changed, 115 insertions(+), 96 deletions(-) create mode 100644 src/plugins/libs/exceptions.py diff --git a/src/dags/libs/create_records.py b/src/dags/libs/create_records.py index 79074fc..b8446cd 100644 --- a/src/dags/libs/create_records.py +++ b/src/dags/libs/create_records.py @@ -27,7 +27,6 @@ from osdu_api.storage.record_client import RecordClient logger = logging.getLogger() - ACL_DICT = eval(Variable.get("acl")) LEGAL_DICT = eval(Variable.get("legal")) @@ -40,7 +39,7 @@ DEFAULT_VERSION = config.get("DEFAULTS", "kind_version") @refresh_token -def send_create_update_record_request(headers, record_client, record): +def create_update_record_request(headers, record_client, record): resp = record_client.create_update_records([record], headers.items()) return resp @@ -76,14 +75,13 @@ def create_records(**kwargs): headers = { "content-type": "application/json", "slb-data-partition-id": data_conf.get("partition-id", DEFAULT_SOURCE), - # "Authorization": f"{auth}", "AppKey": data_conf.get("app-key", "") } record_client = RecordClient() record_client.data_partition_id = data_conf.get( "partition-id", DEFAULT_SOURCE) - resp = send_create_update_record_request(headers, record_client, record) + resp = create_update_record_request(headers, record_client, record) logger.info(f"Response: {resp.text}") kwargs["ti"].xcom_push(key="record_ids", value=resp.json()["recordIds"]) return {"response_status": resp.status_code} diff --git a/src/plugins/libs/exceptions.py b/src/plugins/libs/exceptions.py new file mode 100644 index 0000000..446ec6e --- /dev/null +++ b/src/plugins/libs/exceptions.py @@ -0,0 +1,35 @@ +# Copyright 2020 Google LLC +# Copyright 2020 EPAM Systems +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +class RecordsNotSearchableError(Exception): + """ + Raise when expected totalCount of records differs from actual one. + """ + pass + + +class RefreshSATokenError(Exception): + """ + Raise when token is empty after attempt to get credentials from service account file. + """ + pass + + +class PipelineFailedError(Exception): + """ + Raise when pipeline failed. + """ + pass diff --git a/src/plugins/libs/refresh_token.py b/src/plugins/libs/refresh_token.py index 1f4e8ec..384e475 100644 --- a/src/plugins/libs/refresh_token.py +++ b/src/plugins/libs/refresh_token.py @@ -21,6 +21,7 @@ import requests from airflow.models import Variable from google.auth.transport.requests import Request from google.oauth2 import service_account +from libs.exceptions import RefreshSATokenError from tenacity import retry, stop_after_attempt ACCESS_TOKEN = None @@ -39,16 +40,24 @@ SA_FILE_PATH = Variable.get("sa-file-osdu") ACCESS_SCOPES = ['openid', 'email', 'profile'] +@retry(stop=stop_after_attempt(RETRIES)) def get_access_token(sa_file: str, scopes: list) -> str: """ Refresh access token. """ - credentials = service_account.Credentials.from_service_account_file( - sa_file, scopes=scopes) + try: + credentials = service_account.Credentials.from_service_account_file( + sa_file, scopes=scopes) + except ValueError as e: + logger.error("SA file has bad format.") + raise e logger.info("Refresh token.") credentials.refresh(Request()) - logger.info("Token refreshed.") token = credentials.token + if credentials.token is None: + logger.error("Can't refresh token using SA-file. Token is empty.") + raise RefreshSATokenError + logger.info("Token refreshed.") return token @@ -81,40 +90,45 @@ def _wrapper(*args, **kwargs): _check_token() obj = kwargs.pop("obj") if kwargs.get("obj") else None headers = kwargs.pop("headers") - rqst_func = kwargs.pop("rqst_func") + request_function = kwargs.pop("request_function") if not isinstance(headers, dict): logger.error("Got headers %s" % headers) raise TypeError headers["Authorization"] = f"Bearer {ACCESS_TOKEN}" if obj: # if wrapped function is an object's method - send_request_with_auth = lambda: rqst_func(obj, headers, *args, **kwargs) + send_request_with_auth = lambda: request_function(obj, headers, *args, **kwargs) else: - send_request_with_auth = lambda: rqst_func(headers, *args, **kwargs) + send_request_with_auth = lambda: request_function(headers, *args, **kwargs) response = send_request_with_auth() if not isinstance(response, requests.Response): logger.error("Function %s must return values of type requests.Response. " "Got %s instead" % (kwargs["rqst_func"], type(response))) raise TypeError - if response.status_code in (HTTPStatus.UNAUTHORIZED, HTTPStatus.FORBIDDEN): - set_access_token(SA_FILE_PATH, ACCESS_SCOPES) - response = send_request_with_auth() + if not response.ok: + if response.status_code in (HTTPStatus.UNAUTHORIZED, HTTPStatus.FORBIDDEN): + set_access_token(SA_FILE_PATH, ACCESS_SCOPES) + response = send_request_with_auth() + else: + response.raise_for_status() return response -def refresh_token(rqst_func): +def refresh_token(request_function): """ Wrap a request function and check response. If response's error status code is about Authorization, refresh token and invoke this function once again. Expects function: + If response is not ok and not about Authorization, then raises HTTPError request_func(header: dict, *args, **kwargs) -> requests.Response Or method: request_method(self, header: dict, *args, **kwargs) -> requests.Response """ - is_method = len(rqst_func.__qualname__.split(".")) > 1 + is_method = len(request_function.__qualname__.split(".")) > 1 if is_method: def wrapper(obj, headers, *args, **kwargs): - return _wrapper(rqst_func=rqst_func, obj=obj, headers=headers, *args, **kwargs) + return _wrapper(request_function=request_function, obj=obj, headers=headers, *args, + **kwargs) else: def wrapper(headers, *args, **kwargs): - return _wrapper(rqst_func=rqst_func, headers=headers, *args, **kwargs) + return _wrapper(request_function=request_function, headers=headers, *args, **kwargs) return wrapper diff --git a/src/plugins/operators/search_record_id_op.py b/src/plugins/operators/search_record_id_op.py index 63fd065..90d58d2 100644 --- a/src/plugins/operators/search_record_id_op.py +++ b/src/plugins/operators/search_record_id_op.py @@ -14,7 +14,6 @@ # limitations under the License. -import http import json import logging import sys @@ -24,6 +23,7 @@ import tenacity from airflow.models import BaseOperator, Variable from airflow.utils.decorators import apply_defaults from hooks import search_http_hook, workflow_hook +from libs.exceptions import RecordsNotSearchableError from libs.refresh_token import refresh_token # Set up base logger @@ -60,10 +60,8 @@ class SearchRecordIdOperator(BaseOperator): data_conf = kwargs['dag_run'].conf # for /submitWithManifest authorization and partition-id are inside Payload field if "Payload" in data_conf: - # auth = data_conf["Payload"]["authorization"] partition_id = data_conf["Payload"]["data-partition-id"] else: - # auth = data_conf["authorization"] partition_id = data_conf["data-partition-id"] headers = { 'Content-type': 'application/json', @@ -93,9 +91,9 @@ class SearchRecordIdOperator(BaseOperator): } return request_body, expected_total_count - def _file_searched(self, resp) -> bool: - """Check if search service returns expected totalCount. - The method is used as a callback + def _is_record_searchable(self, resp) -> bool: + """ + Check if search service returns expected totalCount of records. """ data = resp.json() return data.get("totalCount") == self.expected_total_count @@ -110,14 +108,11 @@ class SearchRecordIdOperator(BaseOperator): data=json.dumps(self.request_body), extra_options={"check_response": False} ) - if not response.ok and response.status_code not in ( - http.HTTPStatus.FORBIDDEN, http.HTTPStatus.UNAUTHORIZED): - raise Exception("Error %s, text: %s." % (response.status_code, response.text)) - if not self._file_searched(response): + if not self._is_record_searchable(response): logger.error("Expected amount (%s) of records not found." % self.expected_total_count ) - raise Exception + raise RecordsNotSearchableError return response else: logger.error("There is an error in header or in request body") diff --git a/src/plugins/operators/update_status_op.py b/src/plugins/operators/update_status_op.py index 22f4dd1..1acdc6d 100644 --- a/src/plugins/operators/update_status_op.py +++ b/src/plugins/operators/update_status_op.py @@ -15,17 +15,16 @@ import enum -import http import json import logging import sys from functools import partial -import requests import tenacity from airflow.models import BaseOperator, Variable from airflow.utils.decorators import apply_defaults from hooks import search_http_hook, workflow_hook +from libs.exceptions import PipelineFailedError from libs.refresh_token import refresh_token # Set up base logger @@ -68,15 +67,12 @@ class UpdateStatusOperator(BaseOperator): data_conf = kwargs['dag_run'].conf # for /submitWithManifest authorization and partition-id are inside Payload field if "Payload" in data_conf: - # auth = data_conf["Payload"]["authorization"] partition_id = data_conf["Payload"]["data-partition-id"] else: - # auth = data_conf["authorization"] partition_id = data_conf["data-partition-id"] headers = { 'Content-type': 'application/json', 'data-partition-id': partition_id, - # 'Authorization': auth, } return headers @@ -141,12 +137,12 @@ class UpdateStatusOperator(BaseOperator): If they are then update status FINISHED else FAILED """ headers = self.get_headers(**context) - self.update_status_rqst(headers, self.status, **context) + self.update_status_request(headers, self.status, **context) if self.status == self.FAILED_STATUS: - raise Exception("Dag failed") + raise PipelineFailedError("Dag failed") @refresh_token - def update_status_rqst(self, headers, status, **kwargs): + def update_status_request(self, headers, status, **kwargs): data_conf = kwargs['dag_run'].conf logger.info(f"Got dataconf {data_conf}") workflow_id = data_conf["WorkflowID"] @@ -161,7 +157,4 @@ class UpdateStatusOperator(BaseOperator): headers=headers, extra_options={"check_response": False} ) - if not response.ok and response.status_code not in ( - http.HTTPStatus.UNAUTHORIZED, http.HTTPStatus.FORBIDDEN): - raise Exception("Error %s, text: %s." % (response.status_code, response.text)) return response diff --git a/tests/plugin-unit-tests/test_process_manifest_op.py b/tests/plugin-unit-tests/test_process_manifest_op.py index 8b1ee50..3c516f4 100644 --- a/tests/plugin-unit-tests/test_process_manifest_op.py +++ b/tests/plugin-unit-tests/test_process_manifest_op.py @@ -23,38 +23,38 @@ import pytest sys.path.append(f"{os.getenv('AIRFLOW_SRC_DIR')}/plugins") from data import process_manifest_op as test_data -from operators import process_manifest_op as p_m_op +from operators import process_manifest_op @pytest.mark.parametrize( - "test_input, expected", - [ - ("srn:type:work-product/WellLog:", "WellLog"), - ("srn:type:file/las2:", "las2"), - ] + "test_input, expected", + [ + ("srn:type:work-product/WellLog:", "WellLog"), + ("srn:type:file/las2:", "las2"), + ] ) def test_determine_data_type(test_input, expected): - data_type = p_m_op.determine_data_type(test_input) - assert data_type == expected + data_type = process_manifest_op.determine_data_type(test_input) + assert data_type == expected @pytest.mark.parametrize( - "data_type, loaded_conf, conf_payload, expected_file_result", - [ - ("las2", - test_data.LOADED_CONF, - test_data.CONF_PAYLOAD, - test_data.PROCESS_FILE_ITEMS_RESULT) - ] + "data_type, loaded_conf, conf_payload, expected_file_result", + [ + ("las2", + test_data.LOADED_CONF, + test_data.CONF_PAYLOAD, + test_data.PROCESS_FILE_ITEMS_RESULT) + ] ) def test_process_file_items(data_type, loaded_conf, conf_payload, expected_file_result): - file_id_regex = re.compile(r"srn\:file/" + data_type + r"\:\d+\:") - expected_file_list = expected_file_result[0] - file_list, file_ids = p_m_op.process_file_items(loaded_conf, conf_payload) - for i in file_ids: - assert file_id_regex.match(i) - - for i in file_list: - assert file_id_regex.match(i[0]["data"]["ResourceID"]) - i[0]["data"]["ResourceID"] = "" - assert file_list == expected_file_list + file_id_regex = re.compile(r"srn\:file/" + data_type + r"\:\d+\:") + expected_file_list = expected_file_result[0] + file_list, file_ids = process_manifest_op.process_file_items(loaded_conf, conf_payload) + for i in file_ids: + assert file_id_regex.match(i) + + for i in file_list: + assert file_id_regex.match(i[0]["data"]["ResourceID"]) + i[0]["data"]["ResourceID"] = "" + assert file_list == expected_file_list diff --git a/tests/test_dags.py b/tests/test_dags.py index 648bb3a..ffd6a05 100644 --- a/tests/test_dags.py +++ b/tests/test_dags.py @@ -19,27 +19,26 @@ import time class DagStatus(enum.Enum): - RUNNING = enum.auto() - FAILED = enum.auto() - FINISHED = enum.auto() + RUNNING = "running" + FAILED = "failed" + FINISHED = "finished" + OSDU_INGEST_SUCCESS_SH = "/mock-server/./test-osdu-ingest-success.sh" OSDU_INGEST_FAIL_SH = "/mock-server/./test-osdu-ingest-fail.sh" DEFAULT_INGEST_SUCCESS_SH = "/mock-server/./test-default-ingest-success.sh" DEFAULT_INGEST_FAIL_SH = "/mock-server/./test-default-ingest-fail.sh" -with open("/tmp/osdu_ingest_result", "w") as f: - f.close() - subprocess.run(f"/bin/bash -c 'airflow scheduler > /dev/null 2>&1 &'", shell=True) -def check_dag_status(dag_name): + +def check_dag_status(dag_name: str) -> DagStatus: time.sleep(5) output = subprocess.getoutput(f'airflow list_dag_runs {dag_name}') if "failed" in output: print(dag_name) print(output) - return DagStatus.FAILED + return DagStatus.FAILED if "running" in output: return DagStatus.RUNNING print(dag_name) @@ -47,32 +46,17 @@ def check_dag_status(dag_name): return DagStatus.FINISHED -def test_dag_success(dag_name, script): - print(f"Test {dag_name} success") - subprocess.run(f"{script}", shell=True) - while True: - dag_status = check_dag_status(dag_name) - if dag_status is DagStatus.RUNNING: - continue - elif dag_status is DagStatus.FINISHED: - return - else: - raise Exception(f"Error {dag_name} supposed to be finished") - -def test_dag_fail(dag_name, script): +def test_dag_execution_result(dag_name: str, script: str, expected_status: DagStatus): subprocess.run(f"{script}", shell=True) - print(f"Expecting {dag_name} fail") + print(f"Expecting {dag_name} to be {expected_status.value}") while True: dag_status = check_dag_status(dag_name) - if dag_status is DagStatus.RUNNING: - continue - elif dag_status is DagStatus.FAILED: - return - else: - raise Exception(f"Error {dag_name} supposed to be failed") + if dag_status is not DagStatus.RUNNING: + break + assert dag_status is expected_status, f"Error {dag_name} supposed to be {expected_status.value}" -test_dag_success("Osdu_ingest", OSDU_INGEST_SUCCESS_SH) -test_dag_fail("Osdu_ingest", OSDU_INGEST_FAIL_SH) -test_dag_success("Default_ingest", DEFAULT_INGEST_SUCCESS_SH) -test_dag_fail("Default_ingest", DEFAULT_INGEST_FAIL_SH) +test_dag_execution_result("Osdu_ingest", OSDU_INGEST_SUCCESS_SH, DagStatus.FINISHED) +test_dag_execution_result("Osdu_ingest", OSDU_INGEST_FAIL_SH, DagStatus.FAILED) +test_dag_execution_result("Default_ingest", DEFAULT_INGEST_SUCCESS_SH, DagStatus.FINISHED) +test_dag_execution_result("Default_ingest", DEFAULT_INGEST_FAIL_SH, DagStatus.FAILED) -- GitLab