From 8fa46a6c0878eeb9321a1b29f270a1fac6625afe Mon Sep 17 00:00:00 2001 From: Ernesto Gutierrez <ernesto_gutierrez@epam.com> Date: Mon, 8 Feb 2021 07:37:40 -0600 Subject: [PATCH] GONRG-1689: Pulling authorize logic from Python SDK. --- src/dags/libs/handle_file.py | 6 +- src/dags/libs/manifest_analyzer.py | 5 +- src/dags/libs/process_manifest_r3.py | 4 +- src/dags/libs/refresh_token.py | 187 +----------------- src/dags/libs/search_record_ids.py | 4 +- src/dags/libs/update_status.py | 6 +- src/dags/libs/validate_schema.py | 6 +- src/plugins/operators/process_manifest_r2.py | 6 +- tests/plugin-unit-tests/test_refresh_token.py | 8 +- 9 files changed, 25 insertions(+), 207 deletions(-) diff --git a/src/dags/libs/handle_file.py b/src/dags/libs/handle_file.py index c90d355..a8186e7 100644 --- a/src/dags/libs/handle_file.py +++ b/src/dags/libs/handle_file.py @@ -28,7 +28,7 @@ from libs.constants import RETRIES, WAIT from libs.context import Context from libs.exceptions import InvalidFileRecordData from libs.mixins import HeadersMixin -from libs.refresh_token import TokenRefresher, refresh_token +from osdu_api.libs.auth.authorization import TokenRefresher, authorize from providers import blob_storage from providers.types import BlobStorageClient, FileLikeObject @@ -124,7 +124,7 @@ class FileHandler(HeadersMixin): kind=None) @tenacity.retry(**RETRY_SETTINGS) - @refresh_token() + @authorize() def _send_post_request(self, headers: dict, url: str, request_body: str) -> requests.Response: logger.debug(f"{request_body}") response = requests.post(url, request_body, headers=headers) @@ -132,7 +132,7 @@ class FileHandler(HeadersMixin): return response @tenacity.retry(**RETRY_SETTINGS) - @refresh_token() + @authorize() def _send_get_request(self, headers: dict, url: str) -> requests.Response: response = requests.get(url, headers=headers) logger.debug(response) diff --git a/src/dags/libs/manifest_analyzer.py b/src/dags/libs/manifest_analyzer.py index 839077e..9bd89eb 100644 --- a/src/dags/libs/manifest_analyzer.py +++ b/src/dags/libs/manifest_analyzer.py @@ -26,8 +26,9 @@ import toposort from libs.constants import RETRIES, TIMEOUT from libs.context import Context from libs.mixins import HeadersMixin -from libs.refresh_token import TokenRefresher, refresh_token +from libs.refresh_token import TokenRefresher from libs.traverse_manifest import ManifestEntity +from osdu_api.libs.auth.authorization import authorize logger = logging.getLogger() @@ -153,7 +154,7 @@ class ManifestAnalyzer(HeadersMixin): stop=tenacity.stop_after_attempt(RETRIES), reraise=True ) - @refresh_token() + @authorize() def _get_storage_record_request(self, headers: dict, srn: str) -> requests.Response: logger.debug(f"Searching for {srn}") return requests.get(f"{self.storage_service_url}/{srn}", headers=headers) diff --git a/src/dags/libs/process_manifest_r3.py b/src/dags/libs/process_manifest_r3.py index 2959240..86f06d2 100644 --- a/src/dags/libs/process_manifest_r3.py +++ b/src/dags/libs/process_manifest_r3.py @@ -27,10 +27,10 @@ from libs.context import Context from libs.constants import RETRIES, WAIT from libs.exceptions import EmptyManifestError, NotOSDUSchemaFormatError from libs.mixins import HeadersMixin -from libs.refresh_token import TokenRefresher, refresh_token from libs.source_file_check import SourceFileChecker from libs.handle_file import FileHandler from libs.traverse_manifest import ManifestEntity +from osdu_api.libs.auth.authorization import TokenRefresher, authorize RETRY_SETTINGS = { "stop": tenacity.stop_after_attempt(RETRIES), @@ -156,7 +156,7 @@ class ManifestProcessor(HeadersMixin): raise ValueError(f"Invalid answer from Storage server: {response_dict}") @tenacity.retry(**RETRY_SETTINGS) - @refresh_token() + @authorize() def save_record_to_storage(self, headers: dict, request_data: List[dict]) -> requests.Response: """ Send request to record storage API. diff --git a/src/dags/libs/refresh_token.py b/src/dags/libs/refresh_token.py index 6146da7..284b61c 100644 --- a/src/dags/libs/refresh_token.py +++ b/src/dags/libs/refresh_token.py @@ -21,11 +21,9 @@ import os from abc import ABC, abstractmethod from functools import partial from http import HTTPStatus -from typing import Any, Callable, Union import requests -from libs.exceptions import TokenRefresherNotPresentError -from osdu_api.libs.auth.authorization import TokenRefresher as OSDUAPITokenRefresher +from osdu_api.libs.auth.authorization import TokenRefresher, authorize from providers import credentials from providers.types import BaseCredentials from tenacity import retry, stop_after_attempt @@ -35,42 +33,7 @@ logger = logging.getLogger(__name__) RETRIES = 3 -class TokenRefresher(ABC): - """Abstract base class for token refreshers.""" - - @abstractmethod - def refresh_token(self) -> str: - """Refresh auth token. - - :return: refreshed token - :rtype: str - """ - pass - - @property - @abstractmethod - def access_token(self) -> str: - """Auth access token. - - :return: token string - :rtype: str - """ - pass - - @property - @abstractmethod - def authorization_header(self) -> dict: - """Authorization header. Must return authorization header for - updating headers dict. - E.g. return {"Authorization": "Bearer <access_token>"} - - :return: A dictionary with authorization header updated - :rtype: dict - """ - pass - - -class AirflowTokenRefresher(TokenRefresher, OSDUAPITokenRefresher): +class AirflowTokenRefresher(TokenRefresher): """Simple wrapper for credentials to be used in refresh_token decorator within Airflow.""" def __init__(self, creds: BaseCredentials = None): @@ -113,149 +76,3 @@ class AirflowTokenRefresher(TokenRefresher, OSDUAPITokenRefresher): :rtype: dict """ return {"Authorization": f"Bearer {self.access_token}"} - - -def make_callable_request(obj: Union[object, None], request_function: Callable, headers: dict, - *args, **kwargs) -> Callable: - """Create send_request_with_auth function. - - :param obj: If wrapping a method the obj passed as first argument (self) - :type obj: Union[object, None] - :param request_function: The function to be build - :type request_function: Callable - :param headers: The request headers - :type headers: dict - :return: A partial callable - :rtype: Callable - """ - if obj: # if wrapped function is an object's method - callable_request = partial(request_function, obj, headers, *args, **kwargs) - else: - callable_request = partial(request_function, headers, *args, **kwargs) - return callable_request - - -def _validate_headers_type(headers: dict): - if not isinstance(headers, dict): - logger.error(f"Got headers {headers}") - raise TypeError("Request's headers type expected to be 'dict'") - - -def _validate_response_type(response: requests.Response, request_function: Callable): - if not isinstance(response, requests.Response): - logger.error(f"Function or method {request_function}" - f" must return values of type 'requests.Response'. " - f"Got {type(response)} instead") - raise TypeError - - -def _validate_token_refresher_type(token_refresher: TokenRefresher): - if not isinstance(token_refresher, TokenRefresher): - raise TypeError( - f"Token refresher must be of type {TokenRefresher}. Got {type(token_refresher)}" - ) - - -def _get_object_token_refresher( - token_refresher: TokenRefresher, - obj: object = None -) -> TokenRefresher: - """Get token refresher passed into decorator or specified in object's as - 'token_refresher' property. - - :param token_refresher: A token refresher instance - :type token_refresher: TokenRefresher - :param obj: The holder object of the decorated method, defaults to None - :type obj: object, optional - :raises TokenRefresherNotPresentError: When not found - :return: The token refresher - :rtype: TokenRefresher - """ - if token_refresher is None and obj: - try: - obj.__getattribute__("token_refresher") - except AttributeError: - raise TokenRefresherNotPresentError("Token refresher must be passed into decorator or " - "set as object's 'refresh_token' attribute.") - else: - token_refresher = obj.token_refresher - return token_refresher - - -def send_request_with_auth_header(token_refresher: TokenRefresher, *args, - **kwargs) -> requests.Response: - """Send request with authorization token. If response status is in - HTTPStatus.UNAUTHORIZED or HTTPStatus.FORBIDDEN, then refreshes token - and sends request once again. - - :param token_refresher: The token refresher instance - :type token_refresher: TokenRefresher - :raises e: Re-raises any requests.HTTPError - :return: The server response - :rtype: requests.Response - """ - obj = kwargs.pop("obj", None) - request_function = kwargs.pop("request_function") - headers = kwargs.pop("headers") - _validate_headers_type(headers) - headers.update(token_refresher.authorization_header) - - send_request_with_auth = make_callable_request(obj, request_function, headers, *args, **kwargs) - response = send_request_with_auth() - _validate_response_type(response, request_function) - - if not response.ok: - if response.status_code in (HTTPStatus.UNAUTHORIZED, HTTPStatus.FORBIDDEN): - token_refresher.refresh_token() - headers.update(token_refresher.authorization_header) - send_request_with_auth = make_callable_request(obj, - request_function, - headers, - *args, **kwargs) - response = send_request_with_auth() - try: - response.raise_for_status() - except requests.HTTPError as e: - logger.error(f"{response.text}") - raise e - return response - - -def refresh_token(token_refresher: TokenRefresher = None) -> Callable: - """Wrap a request function and check response. If response's error status - code is about Authorization, refresh token and invoke this function once - again. - Expects function. - If response is not ok and not about Authorization, then raises HTTPError - request_func(header: dict, *args, **kwargs) -> requests.Response - Or method: - request_method(self, header: dict, *args, **kwargs) -> requests.Response - - :param token_refresher: [description], defaults to None - :type token_refresher: TokenRefresher, optional - :return: [description] - :rtype: Callable - """ - - def refresh_token_wrapper(request_function: Callable) -> Callable: - is_method = len(request_function.__qualname__.split(".")) > 1 - if is_method: - def _wrapper(obj: object, headers: dict, *args, **kwargs) -> requests.Response: - _token_refresher = _get_object_token_refresher(token_refresher, obj) - _validate_token_refresher_type(_token_refresher) - return send_request_with_auth_header(_token_refresher, - request_function=request_function, - obj=obj, - headers=headers, - *args, - **kwargs) - else: - def _wrapper(headers: dict, *args, **kwargs) -> requests.Response: - _validate_token_refresher_type(token_refresher) - return send_request_with_auth_header(token_refresher, - request_function=request_function, - headers=headers, - *args, **kwargs) - return _wrapper - - return refresh_token_wrapper diff --git a/src/dags/libs/search_record_ids.py b/src/dags/libs/search_record_ids.py index e45a028..f21620d 100644 --- a/src/dags/libs/search_record_ids.py +++ b/src/dags/libs/search_record_ids.py @@ -23,7 +23,7 @@ import tenacity from libs.context import Context from libs.exceptions import RecordsNotSearchableError from libs.mixins import HeadersMixin -from libs.refresh_token import TokenRefresher, refresh_token +from osdu_api.libs.auth.authorization import TokenRefresher, authorize logger = logging.getLogger() @@ -103,7 +103,7 @@ class SearchId(HeadersMixin): stop=tenacity.stop_after_attempt(RETRIES), reraise=True ) - @refresh_token() + @authorize() def search_files(self, headers: dict) -> requests.Response: """Send request with recordIds to Search service. diff --git a/src/dags/libs/update_status.py b/src/dags/libs/update_status.py index d6ac935..afc0d4f 100644 --- a/src/dags/libs/update_status.py +++ b/src/dags/libs/update_status.py @@ -22,7 +22,7 @@ import requests from libs.context import Context from libs.mixins import HeadersMixin -from libs.refresh_token import TokenRefresher, refresh_token +from osdu_api.libs.auth.authorization import TokenRefresher, authorize logger = logging.getLogger() @@ -66,7 +66,7 @@ class UpdateStatus(HeadersMixin): self.status = status self.token_refresher = token_refresher - @refresh_token() + @authorize() def update_status_request(self, headers: dict) -> requests.Response: """Send request to update status. @@ -85,7 +85,7 @@ class UpdateStatus(HeadersMixin): response = requests.put(update_status_url, request_body, headers=headers) return response - @refresh_token() + @authorize() def update_status_request_old(self, headers: dict) -> requests.Response: """Send request to update status. diff --git a/src/dags/libs/validate_schema.py b/src/dags/libs/validate_schema.py index fa67074..9006e40 100644 --- a/src/dags/libs/validate_schema.py +++ b/src/dags/libs/validate_schema.py @@ -17,7 +17,7 @@ import copy import logging -from typing import Union, Any, List +from typing import Any, List, Union import jsonschema import requests @@ -27,7 +27,7 @@ from libs.context import Context from libs.exceptions import EmptyManifestError, NotOSDUSchemaFormatError from libs.traverse_manifest import ManifestEntity from libs.mixins import HeadersMixin -from libs.refresh_token import TokenRefresher, refresh_token +from osdu_api.libs.auth.authorization import TokenRefresher, authorize logger = logging.getLogger() @@ -100,7 +100,7 @@ class SchemaValidator(HeadersMixin): stop=tenacity.stop_after_attempt(RETRIES), reraise=True ) - @refresh_token() + @authorize() def _get_schema_from_schema_service(self, headers: dict, uri: str) -> requests.Response: """Send request to Schema service to retrieve schema.""" response = requests.get(uri, headers=headers, timeout=60) diff --git a/src/plugins/operators/process_manifest_r2.py b/src/plugins/operators/process_manifest_r2.py index e85ad3a..6a7e6f6 100644 --- a/src/plugins/operators/process_manifest_r2.py +++ b/src/plugins/operators/process_manifest_r2.py @@ -30,8 +30,8 @@ from urllib.error import HTTPError import requests import tenacity from airflow.models import BaseOperator, Variable -from libs.refresh_token import AirflowTokenRefresher, refresh_token - +from libs.refresh_token import AirflowTokenRefresher +from osdu_api.libs.auth.authorization import authorize config = configparser.RawConfigParser() config.read(Variable.get("dataload_config_path")) @@ -306,7 +306,7 @@ def create_workproduct_request_data(loaded_conf: dict, product_type: str, wp, wp stop=tenacity.stop_after_attempt(RETRIES), reraise=True ) -@refresh_token(AirflowTokenRefresher()) +@authorize(AirflowTokenRefresher()) def send_request(headers, request_data): """Send request to records storage API.""" diff --git a/tests/plugin-unit-tests/test_refresh_token.py b/tests/plugin-unit-tests/test_refresh_token.py index 49d98a2..387d402 100644 --- a/tests/plugin-unit-tests/test_refresh_token.py +++ b/tests/plugin-unit-tests/test_refresh_token.py @@ -23,7 +23,7 @@ from google.oauth2 import service_account sys.path.append(f"{os.getenv('AIRFLOW_SRC_DIR')}/dags") -from libs.refresh_token import AirflowTokenRefresher, TokenRefresher +from libs.refresh_token import AirflowTokenRefresher from mock_providers import get_test_credentials @@ -42,7 +42,7 @@ class TestAirflowTokenRefresher: "test" ] ) - def test_access_token_cached(self, token_refresher: TokenRefresher, access_token: str): + def test_access_token_cached(self, token_refresher: AirflowTokenRefresher, access_token: str): """ Check if access token stored in Airflow Variables after refreshing it. """ @@ -56,7 +56,7 @@ class TestAirflowTokenRefresher: "aaaa" ] ) - def test_authorization_header(self, token_refresher: TokenRefresher, access_token: str): + def test_authorization_header(self, token_refresher: AirflowTokenRefresher, access_token: str): """ Check if Authorization header is 'Bearer <access_token>' """ @@ -73,7 +73,7 @@ class TestAirflowTokenRefresher: def test_refresh_token_no_cached_variable( self, monkeypatch, - token_refresher: TokenRefresher, + token_refresher: AirflowTokenRefresher, access_token: str, ): """ -- GitLab