Skip to content
Snippets Groups Projects
Commit 303808d0 authored by Siarhei Khaletski (EPAM)'s avatar Siarhei Khaletski (EPAM) :triangular_flag_on_post:
Browse files

Merge branch 'feature/GONRG-1689_Refactor_add_Python_SDK' into 'integration-master'

GONRG-1689 "Feature/ refactor add python sdk"

Closes GONRG-1689

See merge request go3-nrg/platform/data-flow/ingestion/ingestion-dags!56
parents e39299c7 6a47b67e
No related branches found
No related tags found
No related merge requests found
......@@ -28,7 +28,7 @@ from libs.constants import RETRIES, WAIT
from libs.context import Context
from libs.exceptions import InvalidFileRecordData
from libs.mixins import HeadersMixin
from libs.refresh_token import TokenRefresher, refresh_token
from osdu_api.libs.auth.authorization import TokenRefresher, authorize
from providers import blob_storage
from providers.types import BlobStorageClient, FileLikeObject
......@@ -124,7 +124,7 @@ class FileHandler(HeadersMixin):
kind=None)
@tenacity.retry(**RETRY_SETTINGS)
@refresh_token()
@authorize()
def _send_post_request(self, headers: dict, url: str, request_body: str) -> requests.Response:
logger.debug(f"{request_body}")
response = requests.post(url, request_body, headers=headers)
......@@ -132,7 +132,7 @@ class FileHandler(HeadersMixin):
return response
@tenacity.retry(**RETRY_SETTINGS)
@refresh_token()
@authorize()
def _send_get_request(self, headers: dict, url: str) -> requests.Response:
response = requests.get(url, headers=headers)
logger.debug(response)
......
......@@ -26,8 +26,9 @@ import toposort
from libs.constants import RETRIES, TIMEOUT
from libs.context import Context
from libs.mixins import HeadersMixin
from libs.refresh_token import TokenRefresher, refresh_token
from libs.refresh_token import TokenRefresher
from libs.traverse_manifest import ManifestEntity
from osdu_api.libs.auth.authorization import authorize
logger = logging.getLogger()
......@@ -153,7 +154,7 @@ class ManifestAnalyzer(HeadersMixin):
stop=tenacity.stop_after_attempt(RETRIES),
reraise=True
)
@refresh_token()
@authorize()
def _get_storage_record_request(self, headers: dict, srn: str) -> requests.Response:
logger.debug(f"Searching for {srn}")
return requests.get(f"{self.storage_service_url}/{srn}", headers=headers)
......
......@@ -27,10 +27,10 @@ from libs.context import Context
from libs.constants import RETRIES, WAIT
from libs.exceptions import EmptyManifestError, NotOSDUSchemaFormatError
from libs.mixins import HeadersMixin
from libs.refresh_token import TokenRefresher, refresh_token
from libs.source_file_check import SourceFileChecker
from libs.handle_file import FileHandler
from libs.traverse_manifest import ManifestEntity
from osdu_api.libs.auth.authorization import TokenRefresher, authorize
RETRY_SETTINGS = {
"stop": tenacity.stop_after_attempt(RETRIES),
......@@ -156,7 +156,7 @@ class ManifestProcessor(HeadersMixin):
raise ValueError(f"Invalid answer from Storage server: {response_dict}")
@tenacity.retry(**RETRY_SETTINGS)
@refresh_token()
@authorize()
def save_record_to_storage(self, headers: dict, request_data: List[dict]) -> requests.Response:
"""
Send request to record storage API.
......
......@@ -21,11 +21,9 @@ import os
from abc import ABC, abstractmethod
from functools import partial
from http import HTTPStatus
from typing import Any, Callable, Union
import requests
from libs.exceptions import TokenRefresherNotPresentError
from osdu_api.libs.auth.authorization import TokenRefresher as OSDUAPITokenRefresher
from osdu_api.libs.auth.authorization import TokenRefresher, authorize
from providers import credentials
from providers.types import BaseCredentials
from tenacity import retry, stop_after_attempt
......@@ -35,42 +33,7 @@ logger = logging.getLogger(__name__)
RETRIES = 3
class TokenRefresher(ABC):
    """Interface that every token refresher has to implement.

    ``refresh_token``, ``access_token`` and ``authorization_header`` are all
    abstract, so a subclass can only be instantiated once it overrides the
    three of them.
    """

    @abstractmethod
    def refresh_token(self) -> str:
        """Obtain a fresh auth token.

        :return: the refreshed token
        :rtype: str
        """

    @property
    @abstractmethod
    def access_token(self) -> str:
        """Current auth access token.

        :return: the token string
        :rtype: str
        """

    @property
    @abstractmethod
    def authorization_header(self) -> dict:
        """Authorization header used to update a request's headers dict.

        E.g. ``{"Authorization": "Bearer <access_token>"}``

        :return: a dictionary holding the authorization header
        :rtype: dict
        """
class AirflowTokenRefresher(TokenRefresher, OSDUAPITokenRefresher):
class AirflowTokenRefresher(TokenRefresher):
"""Simple wrapper for credentials to be used in refresh_token decorator within Airflow."""
def __init__(self, creds: BaseCredentials = None):
......@@ -113,149 +76,3 @@ class AirflowTokenRefresher(TokenRefresher, OSDUAPITokenRefresher):
:rtype: dict
"""
return {"Authorization": f"Bearer {self.access_token}"}
def make_callable_request(obj: Union[object, None], request_function: Callable, headers: dict,
                          *args, **kwargs) -> Callable:
    """Bind a request function to its arguments and return a zero-argument callable.

    :param obj: If wrapping a method, the object passed as first argument (self);
        ``None`` when wrapping a plain function
    :type obj: Union[object, None]
    :param request_function: The function or unbound method to bind
    :type request_function: Callable
    :param headers: The request headers, passed right after ``obj`` (if any)
    :type headers: dict
    :return: A partial callable that performs the request when invoked
    :rtype: Callable
    """
    # Compare against None explicitly: an instance may legitimately be falsy
    # (e.g. it defines __bool__ or __len__) and must still be passed as `self`.
    if obj is None:
        return partial(request_function, headers, *args, **kwargs)
    return partial(request_function, obj, headers, *args, **kwargs)
def _validate_headers_type(headers: dict):
if not isinstance(headers, dict):
logger.error(f"Got headers {headers}")
raise TypeError("Request's headers type expected to be 'dict'")
def _validate_response_type(response: requests.Response, request_function: Callable):
    """Ensure the wrapped request function returned a ``requests.Response``.

    :param response: The value returned by ``request_function``
    :type response: requests.Response
    :param request_function: The wrapped function or method, used only for the
        diagnostic message
    :type request_function: Callable
    :raises TypeError: When ``response`` is not a ``requests.Response``
    """
    if isinstance(response, requests.Response):
        return
    message = (f"Function or method {request_function}"
               f" must return values of type 'requests.Response'. "
               f"Got {type(response)} instead")
    logger.error(message)
    # Attach the explanation to the exception as well, so it is not lost when
    # the error is handled somewhere the log output is not visible.
    raise TypeError(message)
def _validate_token_refresher_type(token_refresher: TokenRefresher):
    """Ensure the given object is a ``TokenRefresher`` instance.

    :param token_refresher: The object to check
    :type token_refresher: TokenRefresher
    :raises TypeError: When ``token_refresher`` is not a ``TokenRefresher``
    """
    if isinstance(token_refresher, TokenRefresher):
        return
    actual_type = type(token_refresher)
    raise TypeError(
        f"Token refresher must be of type {TokenRefresher}. Got {actual_type}"
    )
def _get_object_token_refresher(
token_refresher: TokenRefresher,
obj: object = None
) -> TokenRefresher:
"""Get token refresher passed into decorator or specified in object's as
'token_refresher' property.
:param token_refresher: A token refresher instance
:type token_refresher: TokenRefresher
:param obj: The holder object of the decorated method, defaults to None
:type obj: object, optional
:raises TokenRefresherNotPresentError: When not found
:return: The token refresher
:rtype: TokenRefresher
"""
if token_refresher is None and obj:
try:
obj.__getattribute__("token_refresher")
except AttributeError:
raise TokenRefresherNotPresentError("Token refresher must be passed into decorator or "
"set as object's 'refresh_token' attribute.")
else:
token_refresher = obj.token_refresher
return token_refresher
def send_request_with_auth_header(token_refresher: TokenRefresher, *args,
                                  **kwargs) -> requests.Response:
    """Send a request carrying the authorization header, retrying once after a
    token refresh when the server answers HTTPStatus.UNAUTHORIZED or
    HTTPStatus.FORBIDDEN.

    The keyword arguments ``request_function`` and ``headers`` are required and
    ``obj`` is optional; all three are removed from ``kwargs`` before the
    remaining arguments are forwarded to the request function. The passed
    ``headers`` dict is updated in place with the authorization header.

    :param token_refresher: The token refresher instance
    :type token_refresher: TokenRefresher
    :raises requests.HTTPError: Re-raised when the final response is an error;
        the response text is logged first
    :return: The server response
    :rtype: requests.Response
    """
    holder = kwargs.pop("obj", None)
    target = kwargs.pop("request_function")
    request_headers = kwargs.pop("headers")

    _validate_headers_type(request_headers)
    request_headers.update(token_refresher.authorization_header)

    response = make_callable_request(holder, target, request_headers, *args, **kwargs)()
    _validate_response_type(response, target)

    if response.ok:
        return response

    auth_failures = (HTTPStatus.UNAUTHORIZED, HTTPStatus.FORBIDDEN)
    if response.status_code in auth_failures:
        # The token was rejected: refresh it and repeat the request once.
        token_refresher.refresh_token()
        request_headers.update(token_refresher.authorization_header)
        response = make_callable_request(holder, target, request_headers, *args, **kwargs)()

    try:
        response.raise_for_status()
    except requests.HTTPError as e:
        logger.error(f"{response.text}")
        raise e
    return response
def refresh_token(token_refresher: TokenRefresher = None) -> Callable:
    """Wrap a request function and check response. If response's error status
    code is about Authorization, refresh token and invoke this function once
    again.

    If response is not ok and not about Authorization, then raises HTTPError.

    Expects function:
        request_func(header: dict, *args, **kwargs) -> requests.Response
    Or method:
        request_method(self, header: dict, *args, **kwargs) -> requests.Response

    :param token_refresher: The token refresher to use. For methods it may be
        left as None, in which case the refresher is looked up on the object
        the method is called on; defaults to None
    :type token_refresher: TokenRefresher, optional
    :return: A decorator that adds authorization handling to a request function
    :rtype: Callable
    """
    # Imported locally so the block does not depend on a file-level import.
    from functools import wraps

    def refresh_token_wrapper(request_function: Callable) -> Callable:
        # A method's __qualname__ looks like "Class.method". A function defined
        # inside another function looks like "outer.<locals>.inner" and also
        # contains dots, so a plain "has a dot" test would wrongly treat it as a
        # method; only a parent that is not "<locals>" indicates a class.
        owner = request_function.__qualname__.rpartition(".")[0]
        is_method = bool(owner) and not owner.endswith("<locals>")

        if is_method:
            @wraps(request_function)
            def _wrapper(obj: object, headers: dict, *args, **kwargs) -> requests.Response:
                _token_refresher = _get_object_token_refresher(token_refresher, obj)
                _validate_token_refresher_type(_token_refresher)
                return send_request_with_auth_header(_token_refresher,
                                                     *args,
                                                     request_function=request_function,
                                                     obj=obj,
                                                     headers=headers,
                                                     **kwargs)
        else:
            @wraps(request_function)
            def _wrapper(headers: dict, *args, **kwargs) -> requests.Response:
                _validate_token_refresher_type(token_refresher)
                return send_request_with_auth_header(token_refresher,
                                                     *args,
                                                     request_function=request_function,
                                                     headers=headers,
                                                     **kwargs)
        return _wrapper

    return refresh_token_wrapper
......@@ -23,7 +23,7 @@ import tenacity
from libs.context import Context
from libs.exceptions import RecordsNotSearchableError
from libs.mixins import HeadersMixin
from libs.refresh_token import TokenRefresher, refresh_token
from osdu_api.libs.auth.authorization import TokenRefresher, authorize
logger = logging.getLogger()
......@@ -103,7 +103,7 @@ class SearchId(HeadersMixin):
stop=tenacity.stop_after_attempt(RETRIES),
reraise=True
)
@refresh_token()
@authorize()
def search_files(self, headers: dict) -> requests.Response:
"""Send request with recordIds to Search service.
......
......@@ -22,7 +22,7 @@ import requests
from libs.context import Context
from libs.mixins import HeadersMixin
from libs.refresh_token import TokenRefresher, refresh_token
from osdu_api.libs.auth.authorization import TokenRefresher, authorize
logger = logging.getLogger()
......@@ -66,7 +66,7 @@ class UpdateStatus(HeadersMixin):
self.status = status
self.token_refresher = token_refresher
@refresh_token()
@authorize()
def update_status_request(self, headers: dict) -> requests.Response:
"""Send request to update status.
......@@ -85,7 +85,7 @@ class UpdateStatus(HeadersMixin):
response = requests.put(update_status_url, request_body, headers=headers)
return response
@refresh_token()
@authorize()
def update_status_request_old(self, headers: dict) -> requests.Response:
"""Send request to update status.
......
......@@ -17,7 +17,7 @@
import copy
import logging
from typing import Union, Any, List
from typing import Any, List, Union
import jsonschema
import requests
......@@ -27,7 +27,7 @@ from libs.context import Context
from libs.exceptions import EmptyManifestError, NotOSDUSchemaFormatError
from libs.traverse_manifest import ManifestEntity
from libs.mixins import HeadersMixin
from libs.refresh_token import TokenRefresher, refresh_token
from osdu_api.libs.auth.authorization import TokenRefresher, authorize
logger = logging.getLogger()
......@@ -100,7 +100,7 @@ class SchemaValidator(HeadersMixin):
stop=tenacity.stop_after_attempt(RETRIES),
reraise=True
)
@refresh_token()
@authorize()
def _get_schema_from_schema_service(self, headers: dict, uri: str) -> requests.Response:
"""Send request to Schema service to retrieve schema."""
response = requests.get(uri, headers=headers, timeout=60)
......
......@@ -104,7 +104,7 @@ class GoogleCloudStorageClient(BlobStorageClient):
bucket = self._storage_client.bucket(bucket_name)
blob = bucket.get_blob(source_blob_name)
file_as_bytes = blob.download_as_bytes()
file_as_bytes = blob.download_as_string()
logger.debug(f"File {source_blob_name} got from bucket {bucket_name}.")
return file_as_bytes, blob.content_type
......
......@@ -30,8 +30,8 @@ from urllib.error import HTTPError
import requests
import tenacity
from airflow.models import BaseOperator, Variable
from libs.refresh_token import AirflowTokenRefresher, refresh_token
from libs.refresh_token import AirflowTokenRefresher
from osdu_api.libs.auth.authorization import authorize
config = configparser.RawConfigParser()
config.read(Variable.get("dataload_config_path"))
......@@ -306,7 +306,7 @@ def create_workproduct_request_data(loaded_conf: dict, product_type: str, wp, wp
stop=tenacity.stop_after_attempt(RETRIES),
reraise=True
)
@refresh_token(AirflowTokenRefresher())
@authorize(AirflowTokenRefresher())
def send_request(headers, request_data):
"""Send request to records storage API."""
......
......@@ -23,7 +23,7 @@ from google.oauth2 import service_account
sys.path.append(f"{os.getenv('AIRFLOW_SRC_DIR')}/dags")
from libs.refresh_token import AirflowTokenRefresher, TokenRefresher
from libs.refresh_token import AirflowTokenRefresher
from mock_providers import get_test_credentials
......@@ -42,7 +42,7 @@ class TestAirflowTokenRefresher:
"test"
]
)
def test_access_token_cached(self, token_refresher: TokenRefresher, access_token: str):
def test_access_token_cached(self, token_refresher: AirflowTokenRefresher, access_token: str):
"""
Check if access token stored in Airflow Variables after refreshing it.
"""
......@@ -56,7 +56,7 @@ class TestAirflowTokenRefresher:
"aaaa"
]
)
def test_authorization_header(self, token_refresher: TokenRefresher, access_token: str):
def test_authorization_header(self, token_refresher: AirflowTokenRefresher, access_token: str):
"""
Check if Authorization header is 'Bearer <access_token>'
"""
......@@ -73,7 +73,7 @@ class TestAirflowTokenRefresher:
def test_refresh_token_no_cached_variable(
self,
monkeypatch,
token_refresher: TokenRefresher,
token_refresher: AirflowTokenRefresher,
access_token: str,
):
"""
......
......@@ -144,7 +144,7 @@ class TestGoogleCloudStorageClient:
client_mock.bucket.assert_called_with(bucket_name)
bucket_mock.get_blob.assert_called_with(blob_name)
blob_mock.download_as_bytes.assert_called_with()
blob_mock.download_as_string.assert_called_with()
@pytest.mark.parametrize("uri, bucket_name, blob_name, content_type", [
pytest.param("gs://bucket_test/name_test", "bucket_test", "name_test", "text/html"),
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment