From 8fa46a6c0878eeb9321a1b29f270a1fac6625afe Mon Sep 17 00:00:00 2001
From: Ernesto Gutierrez <ernesto_gutierrez@epam.com>
Date: Mon, 8 Feb 2021 07:37:40 -0600
Subject: [PATCH] GONRG-1689: Pulling authorize logic from Python SDK.

---
 src/dags/libs/handle_file.py                  |   6 +-
 src/dags/libs/manifest_analyzer.py            |   5 +-
 src/dags/libs/process_manifest_r3.py          |   4 +-
 src/dags/libs/refresh_token.py                | 187 +-----------------
 src/dags/libs/search_record_ids.py            |   4 +-
 src/dags/libs/update_status.py                |   6 +-
 src/dags/libs/validate_schema.py              |   6 +-
 src/plugins/operators/process_manifest_r2.py  |   6 +-
 tests/plugin-unit-tests/test_refresh_token.py |   8 +-
 9 files changed, 25 insertions(+), 207 deletions(-)

diff --git a/src/dags/libs/handle_file.py b/src/dags/libs/handle_file.py
index c90d355..a8186e7 100644
--- a/src/dags/libs/handle_file.py
+++ b/src/dags/libs/handle_file.py
@@ -28,7 +28,7 @@ from libs.constants import RETRIES, WAIT
 from libs.context import Context
 from libs.exceptions import InvalidFileRecordData
 from libs.mixins import HeadersMixin
-from libs.refresh_token import TokenRefresher, refresh_token
+from osdu_api.libs.auth.authorization import TokenRefresher, authorize
 from providers import blob_storage
 from providers.types import BlobStorageClient, FileLikeObject
 
@@ -124,7 +124,7 @@ class FileHandler(HeadersMixin):
                                            kind=None)
 
     @tenacity.retry(**RETRY_SETTINGS)
-    @refresh_token()
+    @authorize()
     def _send_post_request(self, headers: dict, url: str, request_body: str) -> requests.Response:
         logger.debug(f"{request_body}")
         response = requests.post(url, request_body, headers=headers)
@@ -132,7 +132,7 @@ class FileHandler(HeadersMixin):
         return response
 
     @tenacity.retry(**RETRY_SETTINGS)
-    @refresh_token()
+    @authorize()
     def _send_get_request(self, headers: dict, url: str) -> requests.Response:
         response = requests.get(url, headers=headers)
         logger.debug(response)
diff --git a/src/dags/libs/manifest_analyzer.py b/src/dags/libs/manifest_analyzer.py
index 839077e..9bd89eb 100644
--- a/src/dags/libs/manifest_analyzer.py
+++ b/src/dags/libs/manifest_analyzer.py
@@ -26,8 +26,9 @@ import toposort
 from libs.constants import RETRIES, TIMEOUT
 from libs.context import Context
 from libs.mixins import HeadersMixin
-from libs.refresh_token import TokenRefresher, refresh_token
+from libs.refresh_token import TokenRefresher
 from libs.traverse_manifest import ManifestEntity
+from osdu_api.libs.auth.authorization import authorize
 
 logger = logging.getLogger()
 
@@ -153,7 +154,7 @@ class ManifestAnalyzer(HeadersMixin):
         stop=tenacity.stop_after_attempt(RETRIES),
         reraise=True
     )
-    @refresh_token()
+    @authorize()
     def _get_storage_record_request(self, headers: dict, srn: str) -> requests.Response:
         logger.debug(f"Searching for {srn}")
         return requests.get(f"{self.storage_service_url}/{srn}", headers=headers)
diff --git a/src/dags/libs/process_manifest_r3.py b/src/dags/libs/process_manifest_r3.py
index 2959240..86f06d2 100644
--- a/src/dags/libs/process_manifest_r3.py
+++ b/src/dags/libs/process_manifest_r3.py
@@ -27,10 +27,10 @@ from libs.context import Context
 from libs.constants import RETRIES, WAIT
 from libs.exceptions import EmptyManifestError, NotOSDUSchemaFormatError
 from libs.mixins import HeadersMixin
-from libs.refresh_token import TokenRefresher, refresh_token
 from libs.source_file_check import SourceFileChecker
 from libs.handle_file import FileHandler
 from libs.traverse_manifest import ManifestEntity
+from osdu_api.libs.auth.authorization import TokenRefresher, authorize
 
 RETRY_SETTINGS = {
     "stop": tenacity.stop_after_attempt(RETRIES),
@@ -156,7 +156,7 @@ class ManifestProcessor(HeadersMixin):
             raise ValueError(f"Invalid answer from Storage server: {response_dict}")
 
     @tenacity.retry(**RETRY_SETTINGS)
-    @refresh_token()
+    @authorize()
     def save_record_to_storage(self, headers: dict, request_data: List[dict]) -> requests.Response:
         """
         Send request to record storage API.
diff --git a/src/dags/libs/refresh_token.py b/src/dags/libs/refresh_token.py
index 6146da7..284b61c 100644
--- a/src/dags/libs/refresh_token.py
+++ b/src/dags/libs/refresh_token.py
@@ -21,11 +21,9 @@ import os
 from abc import ABC, abstractmethod
 from functools import partial
 from http import HTTPStatus
-from typing import Any, Callable, Union
 
 import requests
-from libs.exceptions import TokenRefresherNotPresentError
-from osdu_api.libs.auth.authorization import TokenRefresher as OSDUAPITokenRefresher
+from osdu_api.libs.auth.authorization import TokenRefresher, authorize
 from providers import credentials
 from providers.types import BaseCredentials
 from tenacity import retry, stop_after_attempt
@@ -35,42 +33,7 @@ logger = logging.getLogger(__name__)
 RETRIES = 3
 
 
-class TokenRefresher(ABC):
-    """Abstract base class for token refreshers."""
-
-    @abstractmethod
-    def refresh_token(self) -> str:
-        """Refresh auth token.
-
-        :return: refreshed token
-        :rtype: str
-        """
-        pass
-
-    @property
-    @abstractmethod
-    def access_token(self) -> str:
-        """Auth access token.
-
-        :return: token string
-        :rtype: str
-        """
-        pass
-
-    @property
-    @abstractmethod
-    def authorization_header(self) -> dict:
-        """Authorization header. Must return  authorization header for
-        updating headers dict.
-        E.g. return {"Authorization": "Bearer <access_token>"}
-
-        :return: A dictionary with authorization header updated
-        :rtype: dict
-        """
-        pass
-
-
-class AirflowTokenRefresher(TokenRefresher, OSDUAPITokenRefresher):
+class AirflowTokenRefresher(TokenRefresher):
     """Simple wrapper for credentials to be used in refresh_token decorator within Airflow."""
 
     def __init__(self, creds: BaseCredentials = None):
@@ -113,149 +76,3 @@ class AirflowTokenRefresher(TokenRefresher, OSDUAPITokenRefresher):
         :rtype: dict
         """
         return {"Authorization": f"Bearer {self.access_token}"}
-
-
-def make_callable_request(obj: Union[object, None], request_function: Callable, headers: dict,
-                          *args, **kwargs) -> Callable:
-    """Create send_request_with_auth function.
-
-    :param obj: If wrapping a method the obj passed as first argument (self)
-    :type obj: Union[object, None]
-    :param request_function: The function to be build
-    :type request_function: Callable
-    :param headers: The request headers
-    :type headers: dict
-    :return: A partial callable
-    :rtype: Callable
-    """
-    if obj:  # if wrapped function is an object's method
-        callable_request = partial(request_function, obj, headers, *args, **kwargs)
-    else:
-        callable_request = partial(request_function, headers, *args, **kwargs)
-    return callable_request
-
-
-def _validate_headers_type(headers: dict):
-    if not isinstance(headers, dict):
-        logger.error(f"Got headers {headers}")
-        raise TypeError("Request's headers type expected to be 'dict'")
-
-
-def _validate_response_type(response: requests.Response, request_function: Callable):
-    if not isinstance(response, requests.Response):
-        logger.error(f"Function or method {request_function}"
-                     f" must return values of type 'requests.Response'. "
-                     f"Got {type(response)} instead")
-        raise TypeError
-
-
-def _validate_token_refresher_type(token_refresher: TokenRefresher):
-    if not isinstance(token_refresher, TokenRefresher):
-        raise TypeError(
-            f"Token refresher must be of type {TokenRefresher}. Got {type(token_refresher)}"
-        )
-
-
-def _get_object_token_refresher(
-    token_refresher: TokenRefresher,
-    obj: object = None
-) -> TokenRefresher:
-    """Get token refresher passed into decorator or specified in object's as
-    'token_refresher' property.
-
-    :param token_refresher: A token refresher instance
-    :type token_refresher: TokenRefresher
-    :param obj: The holder object of the decorated method, defaults to None
-    :type obj: object, optional
-    :raises TokenRefresherNotPresentError: When not found
-    :return: The token refresher
-    :rtype: TokenRefresher
-    """
-    if token_refresher is None and obj:
-        try:
-            obj.__getattribute__("token_refresher")
-        except AttributeError:
-            raise TokenRefresherNotPresentError("Token refresher must be passed into decorator or "
-                                                "set as object's 'refresh_token' attribute.")
-        else:
-            token_refresher = obj.token_refresher
-    return token_refresher
-
-
-def send_request_with_auth_header(token_refresher: TokenRefresher, *args,
-                                  **kwargs) -> requests.Response:
-    """Send request with authorization token. If response status is in
-    HTTPStatus.UNAUTHORIZED or HTTPStatus.FORBIDDEN, then refreshes token
-    and sends request once again.
-
-    :param token_refresher: The token refresher instance
-    :type token_refresher: TokenRefresher
-    :raises e: Re-raises any requests.HTTPError
-    :return: The server response
-    :rtype: requests.Response
-    """
-    obj = kwargs.pop("obj", None)
-    request_function = kwargs.pop("request_function")
-    headers = kwargs.pop("headers")
-    _validate_headers_type(headers)
-    headers.update(token_refresher.authorization_header)
-
-    send_request_with_auth = make_callable_request(obj, request_function, headers, *args, **kwargs)
-    response = send_request_with_auth()
-    _validate_response_type(response, request_function)
-
-    if not response.ok:
-        if response.status_code in (HTTPStatus.UNAUTHORIZED, HTTPStatus.FORBIDDEN):
-            token_refresher.refresh_token()
-            headers.update(token_refresher.authorization_header)
-            send_request_with_auth = make_callable_request(obj,
-                                                           request_function,
-                                                           headers,
-                                                           *args, **kwargs)
-            response = send_request_with_auth()
-        try:
-            response.raise_for_status()
-        except requests.HTTPError as e:
-            logger.error(f"{response.text}")
-            raise e
-    return response
-
-
-def refresh_token(token_refresher: TokenRefresher = None) -> Callable:
-    """Wrap a request function and check response. If response's error status
-    code is about Authorization, refresh token and invoke this function once
-    again.
-    Expects function.
-    If response is not ok and not about Authorization, then raises HTTPError
-    request_func(header: dict, *args, **kwargs) -> requests.Response
-    Or method:
-    request_method(self, header: dict, *args, **kwargs) -> requests.Response
-
-    :param token_refresher: [description], defaults to None
-    :type token_refresher: TokenRefresher, optional
-    :return: [description]
-    :rtype: Callable
-    """
-
-    def refresh_token_wrapper(request_function: Callable) -> Callable:
-        is_method = len(request_function.__qualname__.split(".")) > 1
-        if is_method:
-            def _wrapper(obj: object, headers: dict, *args, **kwargs) -> requests.Response:
-                _token_refresher = _get_object_token_refresher(token_refresher, obj)
-                _validate_token_refresher_type(_token_refresher)
-                return send_request_with_auth_header(_token_refresher,
-                                                     request_function=request_function,
-                                                     obj=obj,
-                                                     headers=headers,
-                                                     *args,
-                                                     **kwargs)
-        else:
-            def _wrapper(headers: dict, *args, **kwargs) -> requests.Response:
-                _validate_token_refresher_type(token_refresher)
-                return send_request_with_auth_header(token_refresher,
-                                                     request_function=request_function,
-                                                     headers=headers,
-                                                     *args, **kwargs)
-        return _wrapper
-
-    return refresh_token_wrapper
diff --git a/src/dags/libs/search_record_ids.py b/src/dags/libs/search_record_ids.py
index e45a028..f21620d 100644
--- a/src/dags/libs/search_record_ids.py
+++ b/src/dags/libs/search_record_ids.py
@@ -23,7 +23,7 @@ import tenacity
 from libs.context import Context
 from libs.exceptions import RecordsNotSearchableError
 from libs.mixins import HeadersMixin
-from libs.refresh_token import TokenRefresher, refresh_token
+from osdu_api.libs.auth.authorization import TokenRefresher, authorize
 
 logger = logging.getLogger()
 
@@ -103,7 +103,7 @@ class SearchId(HeadersMixin):
         stop=tenacity.stop_after_attempt(RETRIES),
         reraise=True
     )
-    @refresh_token()
+    @authorize()
     def search_files(self, headers: dict) -> requests.Response:
         """Send request with recordIds to Search service.
 
diff --git a/src/dags/libs/update_status.py b/src/dags/libs/update_status.py
index d6ac935..afc0d4f 100644
--- a/src/dags/libs/update_status.py
+++ b/src/dags/libs/update_status.py
@@ -22,7 +22,7 @@ import requests
 
 from libs.context import Context
 from libs.mixins import HeadersMixin
-from libs.refresh_token import TokenRefresher, refresh_token
+from osdu_api.libs.auth.authorization import TokenRefresher, authorize
 
 logger = logging.getLogger()
 
@@ -66,7 +66,7 @@ class UpdateStatus(HeadersMixin):
         self.status = status
         self.token_refresher = token_refresher
 
-    @refresh_token()
+    @authorize()
     def update_status_request(self, headers: dict) -> requests.Response:
         """Send request to update status.
 
@@ -85,7 +85,7 @@ class UpdateStatus(HeadersMixin):
         response = requests.put(update_status_url, request_body, headers=headers)
         return response
 
-    @refresh_token()
+    @authorize()
     def update_status_request_old(self, headers: dict) -> requests.Response:
         """Send request to update status.
 
diff --git a/src/dags/libs/validate_schema.py b/src/dags/libs/validate_schema.py
index fa67074..9006e40 100644
--- a/src/dags/libs/validate_schema.py
+++ b/src/dags/libs/validate_schema.py
@@ -17,7 +17,7 @@
 
 import copy
 import logging
-from typing import Union, Any, List
+from typing import Any, List, Union
 
 import jsonschema
 import requests
@@ -27,7 +27,7 @@ from libs.context import Context
 from libs.exceptions import EmptyManifestError, NotOSDUSchemaFormatError
 from libs.traverse_manifest import ManifestEntity
 from libs.mixins import HeadersMixin
-from libs.refresh_token import TokenRefresher, refresh_token
+from osdu_api.libs.auth.authorization import TokenRefresher, authorize
 
 logger = logging.getLogger()
 
@@ -100,7 +100,7 @@ class SchemaValidator(HeadersMixin):
         stop=tenacity.stop_after_attempt(RETRIES),
         reraise=True
     )
-    @refresh_token()
+    @authorize()
     def _get_schema_from_schema_service(self, headers: dict, uri: str) -> requests.Response:
         """Send request to Schema service to retrieve schema."""
         response = requests.get(uri, headers=headers, timeout=60)
diff --git a/src/plugins/operators/process_manifest_r2.py b/src/plugins/operators/process_manifest_r2.py
index e85ad3a..6a7e6f6 100644
--- a/src/plugins/operators/process_manifest_r2.py
+++ b/src/plugins/operators/process_manifest_r2.py
@@ -30,8 +30,8 @@ from urllib.error import HTTPError
 import requests
 import tenacity
 from airflow.models import BaseOperator, Variable
-from libs.refresh_token import AirflowTokenRefresher, refresh_token
-
+from libs.refresh_token import AirflowTokenRefresher
+from osdu_api.libs.auth.authorization import authorize
 
 config = configparser.RawConfigParser()
 config.read(Variable.get("dataload_config_path"))
@@ -306,7 +306,7 @@ def create_workproduct_request_data(loaded_conf: dict, product_type: str, wp, wp
     stop=tenacity.stop_after_attempt(RETRIES),
     reraise=True
 )
-@refresh_token(AirflowTokenRefresher())
+@authorize(AirflowTokenRefresher())
 def send_request(headers, request_data):
     """Send request to records storage API."""
 
diff --git a/tests/plugin-unit-tests/test_refresh_token.py b/tests/plugin-unit-tests/test_refresh_token.py
index 49d98a2..387d402 100644
--- a/tests/plugin-unit-tests/test_refresh_token.py
+++ b/tests/plugin-unit-tests/test_refresh_token.py
@@ -23,7 +23,7 @@ from google.oauth2 import service_account
 
 sys.path.append(f"{os.getenv('AIRFLOW_SRC_DIR')}/dags")
 
-from libs.refresh_token import AirflowTokenRefresher, TokenRefresher
+from libs.refresh_token import AirflowTokenRefresher
 from mock_providers import get_test_credentials
 
 
@@ -42,7 +42,7 @@ class TestAirflowTokenRefresher:
             "test"
         ]
     )
-    def test_access_token_cached(self, token_refresher: TokenRefresher, access_token: str):
+    def test_access_token_cached(self, token_refresher: AirflowTokenRefresher, access_token: str):
         """
         Check if access token stored in Airflow Variables after refreshing it.
         """
@@ -56,7 +56,7 @@ class TestAirflowTokenRefresher:
             "aaaa"
         ]
     )
-    def test_authorization_header(self, token_refresher: TokenRefresher, access_token: str):
+    def test_authorization_header(self, token_refresher: AirflowTokenRefresher, access_token: str):
         """
         Check if Authorization header is 'Bearer <access_token>'
         """
@@ -73,7 +73,7 @@ class TestAirflowTokenRefresher:
     def test_refresh_token_no_cached_variable(
         self,
         monkeypatch,
-        token_refresher: TokenRefresher,
+        token_refresher: AirflowTokenRefresher,
         access_token: str,
     ):
         """
-- 
GitLab