From 710b8152fdbc8331fbe2a0310ef60937efeea3a2 Mon Sep 17 00:00:00 2001
From: yan <yan_sushchynski@epam.com>
Date: Mon, 26 Oct 2020 16:47:58 +0300
Subject: [PATCH] GONRG-790: Change naming. Add retry to get_token.

---
 src/dags/libs/create_records.py               |  6 +--
 src/plugins/libs/exceptions.py                | 35 +++++++++++++
 src/plugins/libs/refresh_token.py             | 40 ++++++++++-----
 src/plugins/operators/search_record_id_op.py  | 17 +++----
 src/plugins/operators/update_status_op.py     | 15 ++----
 .../test_process_manifest_op.py               | 50 +++++++++----------
 tests/test_dags.py                            | 48 ++++++------------
 7 files changed, 115 insertions(+), 96 deletions(-)
 create mode 100644 src/plugins/libs/exceptions.py

diff --git a/src/dags/libs/create_records.py b/src/dags/libs/create_records.py
index 79074fc..b8446cd 100644
--- a/src/dags/libs/create_records.py
+++ b/src/dags/libs/create_records.py
@@ -27,7 +27,6 @@ from osdu_api.storage.record_client import RecordClient
 
 logger = logging.getLogger()
 
-
 ACL_DICT = eval(Variable.get("acl"))
 LEGAL_DICT = eval(Variable.get("legal"))
 
@@ -40,7 +39,7 @@ DEFAULT_VERSION = config.get("DEFAULTS", "kind_version")
 
 
 @refresh_token
-def send_create_update_record_request(headers, record_client, record):
+def create_update_record_request(headers, record_client, record):
     resp = record_client.create_update_records([record], headers.items())
     return resp
 
@@ -76,14 +75,13 @@ def create_records(**kwargs):
     headers = {
         "content-type": "application/json",
         "slb-data-partition-id": data_conf.get("partition-id", DEFAULT_SOURCE),
-        # "Authorization": f"{auth}",
         "AppKey": data_conf.get("app-key", "")
     }
 
     record_client = RecordClient()
     record_client.data_partition_id = data_conf.get(
         "partition-id", DEFAULT_SOURCE)
-    resp = send_create_update_record_request(headers, record_client, record)
+    resp = create_update_record_request(headers, record_client, record)
     logger.info(f"Response: {resp.text}")
     kwargs["ti"].xcom_push(key="record_ids", value=resp.json()["recordIds"])
     return {"response_status": resp.status_code}
diff --git a/src/plugins/libs/exceptions.py b/src/plugins/libs/exceptions.py
new file mode 100644
index 0000000..446ec6e
--- /dev/null
+++ b/src/plugins/libs/exceptions.py
@@ -0,0 +1,35 @@
+#  Copyright 2020 Google LLC
+#  Copyright 2020 EPAM Systems
+#
+#  Licensed under the Apache License, Version 2.0 (the "License");
+#  you may not use this file except in compliance with the License.
+#  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+#  limitations under the License.
+
+
+class RecordsNotSearchableError(Exception):
+    """
+    Raised when the totalCount returned by the search service differs from the expected one.
+    """
+    pass
+
+
+class RefreshSATokenError(Exception):
+    """
+    Raised when the token is empty after refreshing the service account credentials.
+    """
+    pass
+
+
+class PipelineFailedError(Exception):
+    """
+    Raised when the pipeline has failed.
+    """
+    pass
diff --git a/src/plugins/libs/refresh_token.py b/src/plugins/libs/refresh_token.py
index 1f4e8ec..384e475 100644
--- a/src/plugins/libs/refresh_token.py
+++ b/src/plugins/libs/refresh_token.py
@@ -21,6 +21,7 @@ import requests
 from airflow.models import Variable
 from google.auth.transport.requests import Request
 from google.oauth2 import service_account
+from libs.exceptions import RefreshSATokenError
 from tenacity import retry, stop_after_attempt
 
 ACCESS_TOKEN = None
@@ -39,16 +40,24 @@ SA_FILE_PATH = Variable.get("sa-file-osdu")
 ACCESS_SCOPES = ['openid', 'email', 'profile']
 
 
+@retry(stop=stop_after_attempt(RETRIES), reraise=True)
 def get_access_token(sa_file: str, scopes: list) -> str:
     """
-    Refresh access token.
+    Get a fresh access token using the SA file; the last error is re-raised after RETRIES tries.
     """
-    credentials = service_account.Credentials.from_service_account_file(
-        sa_file, scopes=scopes)
+    try:
+        credentials = service_account.Credentials.from_service_account_file(
+            sa_file, scopes=scopes)
+    except ValueError:
+        logger.error("SA file has bad format.")
+        raise
     logger.info("Refresh token.")
     credentials.refresh(Request())
-    logger.info("Token refreshed.")
     token = credentials.token
+    if token is None:
+        logger.error("Can't refresh token using SA-file. Token is empty.")
+        raise RefreshSATokenError("Token is empty after refreshing SA credentials.")
+    logger.info("Token refreshed.")
     return token
 
 
@@ -81,40 +90,45 @@ def _wrapper(*args, **kwargs):
     _check_token()
     obj = kwargs.pop("obj") if kwargs.get("obj") else None
     headers = kwargs.pop("headers")
-    rqst_func = kwargs.pop("rqst_func")
+    request_function = kwargs.pop("request_function")
     if not isinstance(headers, dict):
         logger.error("Got headers %s" % headers)
         raise TypeError
     headers["Authorization"] = f"Bearer {ACCESS_TOKEN}"
     if obj:  # if wrapped function is an object's method
-        send_request_with_auth = lambda: rqst_func(obj, headers, *args, **kwargs)
+        send_request_with_auth = lambda: request_function(obj, headers, *args, **kwargs)
     else:
-        send_request_with_auth = lambda: rqst_func(headers, *args, **kwargs)
+        send_request_with_auth = lambda: request_function(headers, *args, **kwargs)
     response = send_request_with_auth()
     if not isinstance(response, requests.Response):
         logger.error("Function %s must return values of type requests.Response. "
-                     "Got %s instead" % (kwargs["rqst_func"], type(response)))
+                     "Got %s instead" % (request_function, type(response)))
         raise TypeError
-    if response.status_code in (HTTPStatus.UNAUTHORIZED, HTTPStatus.FORBIDDEN):
-        set_access_token(SA_FILE_PATH, ACCESS_SCOPES)
-        response = send_request_with_auth()
+    if not response.ok:
+        if response.status_code in (HTTPStatus.UNAUTHORIZED, HTTPStatus.FORBIDDEN):
+            set_access_token(SA_FILE_PATH, ACCESS_SCOPES)
+            response = send_request_with_auth()
+        else:
+            response.raise_for_status()
     return response
 
 
-def refresh_token(rqst_func):
+def refresh_token(request_function):
     """
     Wrap a request function and check response. If response's error status code
     is about Authorization, refresh token and invoke this function once again.
+    If the response is not ok and the error is not about Authorization, raises requests.HTTPError.
     Expects function:
     request_func(header: dict, *args, **kwargs) -> requests.Response
     Or method:
     request_method(self, header: dict, *args, **kwargs) -> requests.Response
     """
-    is_method = len(rqst_func.__qualname__.split(".")) > 1
+    is_method = len(request_function.__qualname__.split(".")) > 1
     if is_method:
         def wrapper(obj, headers, *args, **kwargs):
-            return _wrapper(rqst_func=rqst_func, obj=obj, headers=headers, *args, **kwargs)
+            return _wrapper(request_function=request_function, obj=obj, headers=headers, *args,
+                            **kwargs)
     else:
         def wrapper(headers, *args, **kwargs):
-            return _wrapper(rqst_func=rqst_func, headers=headers, *args, **kwargs)
+            return _wrapper(request_function=request_function, headers=headers, *args, **kwargs)
     return wrapper
diff --git a/src/plugins/operators/search_record_id_op.py b/src/plugins/operators/search_record_id_op.py
index 63fd065..90d58d2 100644
--- a/src/plugins/operators/search_record_id_op.py
+++ b/src/plugins/operators/search_record_id_op.py
@@ -14,7 +14,6 @@
 #  limitations under the License.
 
 
-import http
 import json
 import logging
 import sys
@@ -24,6 +23,7 @@ import tenacity
 from airflow.models import BaseOperator, Variable
 from airflow.utils.decorators import apply_defaults
 from hooks import search_http_hook, workflow_hook
+from libs.exceptions import RecordsNotSearchableError
 from libs.refresh_token import refresh_token
 
 # Set up base logger
@@ -60,10 +60,8 @@ class SearchRecordIdOperator(BaseOperator):
         data_conf = kwargs['dag_run'].conf
         # for /submitWithManifest authorization and partition-id are inside Payload field
         if "Payload" in data_conf:
-            # auth = data_conf["Payload"]["authorization"]
             partition_id = data_conf["Payload"]["data-partition-id"]
         else:
-            # auth = data_conf["authorization"]
             partition_id = data_conf["data-partition-id"]
         headers = {
             'Content-type': 'application/json',
@@ -93,9 +91,9 @@ class SearchRecordIdOperator(BaseOperator):
         }
         return request_body, expected_total_count
 
-    def _file_searched(self, resp) -> bool:
-        """Check if search service returns expected totalCount.
-        The method is used  as a callback
+    def _is_record_searchable(self, resp) -> bool:
+        """
+        Check if search service returns expected totalCount of records.
         """
         data = resp.json()
         return data.get("totalCount") == self.expected_total_count
@@ -110,14 +108,11 @@ class SearchRecordIdOperator(BaseOperator):
                 data=json.dumps(self.request_body),
                 extra_options={"check_response": False}
             )
-            if not response.ok and response.status_code not in (
-            http.HTTPStatus.FORBIDDEN, http.HTTPStatus.UNAUTHORIZED):
-                raise Exception("Error %s, text: %s." % (response.status_code, response.text))
-            if not self._file_searched(response):
+            if not self._is_record_searchable(response):
                 logger.error("Expected amount (%s) of records not found." %
                              self.expected_total_count
                              )
-                raise Exception
+                raise RecordsNotSearchableError
             return response
         else:
             logger.error("There is an error in header or in request body")
diff --git a/src/plugins/operators/update_status_op.py b/src/plugins/operators/update_status_op.py
index 22f4dd1..1acdc6d 100644
--- a/src/plugins/operators/update_status_op.py
+++ b/src/plugins/operators/update_status_op.py
@@ -15,17 +15,16 @@
 
 
 import enum
-import http
 import json
 import logging
 import sys
 from functools import partial
 
-import requests
 import tenacity
 from airflow.models import BaseOperator, Variable
 from airflow.utils.decorators import apply_defaults
 from hooks import search_http_hook, workflow_hook
+from libs.exceptions import PipelineFailedError
 from libs.refresh_token import refresh_token
 
 # Set up base logger
@@ -68,15 +67,12 @@ class UpdateStatusOperator(BaseOperator):
         data_conf = kwargs['dag_run'].conf
         # for /submitWithManifest authorization and partition-id are inside Payload field
         if "Payload" in data_conf:
-            # auth = data_conf["Payload"]["authorization"]
             partition_id = data_conf["Payload"]["data-partition-id"]
         else:
-            # auth = data_conf["authorization"]
             partition_id = data_conf["data-partition-id"]
         headers = {
             'Content-type': 'application/json',
             'data-partition-id': partition_id,
-            # 'Authorization': auth,
         }
         return headers
 
@@ -141,12 +137,12 @@ class UpdateStatusOperator(BaseOperator):
            If they are then update status FINISHED else FAILED
         """
         headers = self.get_headers(**context)
-        self.update_status_rqst(headers, self.status, **context)
+        self.update_status_request(headers, self.status, **context)
         if self.status == self.FAILED_STATUS:
-            raise Exception("Dag failed")
+            raise PipelineFailedError("Dag failed")
 
     @refresh_token
-    def update_status_rqst(self, headers, status, **kwargs):
+    def update_status_request(self, headers, status, **kwargs):
         data_conf = kwargs['dag_run'].conf
         logger.info(f"Got dataconf {data_conf}")
         workflow_id = data_conf["WorkflowID"]
@@ -161,7 +157,4 @@ class UpdateStatusOperator(BaseOperator):
             headers=headers,
             extra_options={"check_response": False}
         )
-        if not response.ok and response.status_code not in (
-        http.HTTPStatus.UNAUTHORIZED, http.HTTPStatus.FORBIDDEN):
-            raise Exception("Error %s, text: %s." % (response.status_code, response.text))
         return response
diff --git a/tests/plugin-unit-tests/test_process_manifest_op.py b/tests/plugin-unit-tests/test_process_manifest_op.py
index 8b1ee50..3c516f4 100644
--- a/tests/plugin-unit-tests/test_process_manifest_op.py
+++ b/tests/plugin-unit-tests/test_process_manifest_op.py
@@ -23,38 +23,38 @@ import pytest
 sys.path.append(f"{os.getenv('AIRFLOW_SRC_DIR')}/plugins")
 
 from data import process_manifest_op as test_data
-from operators import process_manifest_op as p_m_op
+from operators import process_manifest_op
 
 
 @pytest.mark.parametrize(
-   "test_input, expected",
-   [
-      ("srn:type:work-product/WellLog:", "WellLog"),
-      ("srn:type:file/las2:", "las2"),
-   ]
+    "test_input, expected",
+    [
+        ("srn:type:work-product/WellLog:", "WellLog"),
+        ("srn:type:file/las2:", "las2"),
+    ]
 )
 def test_determine_data_type(test_input, expected):
-   data_type = p_m_op.determine_data_type(test_input)
-   assert data_type == expected
+    data_type = process_manifest_op.determine_data_type(test_input)
+    assert data_type == expected
 
 
 @pytest.mark.parametrize(
-   "data_type, loaded_conf, conf_payload, expected_file_result",
-   [
-       ("las2",
-        test_data.LOADED_CONF,
-        test_data.CONF_PAYLOAD,
-        test_data.PROCESS_FILE_ITEMS_RESULT)
-   ]
+    "data_type, loaded_conf, conf_payload, expected_file_result",
+    [
+        ("las2",
+         test_data.LOADED_CONF,
+         test_data.CONF_PAYLOAD,
+         test_data.PROCESS_FILE_ITEMS_RESULT)
+    ]
 )
 def test_process_file_items(data_type, loaded_conf, conf_payload, expected_file_result):
-   file_id_regex = re.compile(r"srn\:file/" + data_type + r"\:\d+\:")
-   expected_file_list = expected_file_result[0]
-   file_list, file_ids = p_m_op.process_file_items(loaded_conf, conf_payload)
-   for i in file_ids:
-      assert file_id_regex.match(i)
-
-   for i in file_list:
-      assert file_id_regex.match(i[0]["data"]["ResourceID"])
-      i[0]["data"]["ResourceID"] = ""
-   assert file_list == expected_file_list
+    file_id_regex = re.compile(r"srn\:file/" + data_type + r"\:\d+\:")
+    expected_file_list = expected_file_result[0]
+    file_list, file_ids = process_manifest_op.process_file_items(loaded_conf, conf_payload)
+    for i in file_ids:
+        assert file_id_regex.match(i)
+
+    for i in file_list:
+        assert file_id_regex.match(i[0]["data"]["ResourceID"])
+        i[0]["data"]["ResourceID"] = ""
+    assert file_list == expected_file_list
diff --git a/tests/test_dags.py b/tests/test_dags.py
index 648bb3a..ffd6a05 100644
--- a/tests/test_dags.py
+++ b/tests/test_dags.py
@@ -19,27 +19,26 @@ import time
 
 
 class DagStatus(enum.Enum):
-    RUNNING = enum.auto()
-    FAILED = enum.auto()
-    FINISHED = enum.auto()
+    RUNNING = "running"
+    FAILED = "failed"
+    FINISHED = "finished"
+
 
 OSDU_INGEST_SUCCESS_SH = "/mock-server/./test-osdu-ingest-success.sh"
 OSDU_INGEST_FAIL_SH = "/mock-server/./test-osdu-ingest-fail.sh"
 DEFAULT_INGEST_SUCCESS_SH = "/mock-server/./test-default-ingest-success.sh"
 DEFAULT_INGEST_FAIL_SH = "/mock-server/./test-default-ingest-fail.sh"
 
-with open("/tmp/osdu_ingest_result", "w") as f:
-    f.close()
-
 subprocess.run(f"/bin/bash -c 'airflow scheduler > /dev/null 2>&1 &'", shell=True)
 
-def check_dag_status(dag_name):
+
+def check_dag_status(dag_name: str) -> DagStatus:
     time.sleep(5)
     output = subprocess.getoutput(f'airflow list_dag_runs {dag_name}')
     if "failed" in output:
         print(dag_name)
         print(output)
-        return  DagStatus.FAILED
+        return DagStatus.FAILED
     if "running" in output:
         return DagStatus.RUNNING
     print(dag_name)
@@ -47,32 +46,17 @@ def check_dag_status(dag_name):
     return DagStatus.FINISHED
 
 
-def test_dag_success(dag_name, script):
-    print(f"Test {dag_name} success")
-    subprocess.run(f"{script}", shell=True)
-    while True:
-        dag_status = check_dag_status(dag_name)
-        if dag_status is DagStatus.RUNNING:
-            continue
-        elif dag_status is DagStatus.FINISHED:
-            return
-        else:
-            raise Exception(f"Error {dag_name} supposed to be finished")
-
-def test_dag_fail(dag_name, script):
+def test_dag_execution_result(dag_name: str, script: str, expected_status: DagStatus):
     subprocess.run(f"{script}", shell=True)
-    print(f"Expecting {dag_name} fail")
+    print(f"Expecting {dag_name} to be {expected_status.value}")
     while True:
         dag_status = check_dag_status(dag_name)
-        if dag_status is DagStatus.RUNNING:
-            continue
-        elif dag_status is DagStatus.FAILED:
-            return
-        else:
-            raise Exception(f"Error {dag_name} supposed to be failed")
+        if dag_status is not DagStatus.RUNNING:
+            break
+    assert dag_status is expected_status, f"Error {dag_name} supposed to be {expected_status.value}"
 
 
-test_dag_success("Osdu_ingest", OSDU_INGEST_SUCCESS_SH)
-test_dag_fail("Osdu_ingest", OSDU_INGEST_FAIL_SH)
-test_dag_success("Default_ingest", DEFAULT_INGEST_SUCCESS_SH)
-test_dag_fail("Default_ingest", DEFAULT_INGEST_FAIL_SH)
+test_dag_execution_result("Osdu_ingest", OSDU_INGEST_SUCCESS_SH, DagStatus.FINISHED)
+test_dag_execution_result("Osdu_ingest", OSDU_INGEST_FAIL_SH, DagStatus.FAILED)
+test_dag_execution_result("Default_ingest", DEFAULT_INGEST_SUCCESS_SH, DagStatus.FINISHED)
+test_dag_execution_result("Default_ingest", DEFAULT_INGEST_FAIL_SH, DagStatus.FAILED)
-- 
GitLab