From 43c5ff474d103dedae193fac54dfb984980e0ab0 Mon Sep 17 00:00:00 2001 From: yan <yan_sushchynski@epam.com> Date: Fri, 20 Nov 2020 11:47:58 +0300 Subject: [PATCH] GONRG-1180: Add refresh token using bucket --- scripts/convert_R2_schemas_to_R3.py | 81 ------ src/dags/libs/__init__.py | 14 + src/dags/libs/create_records.py | 4 +- src/dags/libs/exceptions.py | 12 + src/dags/libs/refresh_token.py | 263 +++++++++++++----- src/dags/osdu-ingest-r3.py | 10 +- .../operators/process_manifest_r2_op.py | 4 +- src/plugins/operators/process_manifest_r3.py | 6 +- src/plugins/operators/search_record_id_op.py | 4 +- src/plugins/operators/update_status_op.py | 4 +- tests/plugin-unit-tests/test_refresh_token.py | 50 ++++ tests/set_airflow_env.sh | 3 + 12 files changed, 279 insertions(+), 176 deletions(-) delete mode 100644 scripts/convert_R2_schemas_to_R3.py create mode 100644 src/dags/libs/__init__.py create mode 100644 tests/plugin-unit-tests/test_refresh_token.py diff --git a/scripts/convert_R2_schemas_to_R3.py b/scripts/convert_R2_schemas_to_R3.py deleted file mode 100644 index fa8572f..0000000 --- a/scripts/convert_R2_schemas_to_R3.py +++ /dev/null @@ -1,81 +0,0 @@ -import glob -import json -import os -import re - -from collections import UserString - -TENANT = "opendes" -AUTHORITY = "osdu" -SCHEMAS_DIR = os.environ["SCHEMAS_DIR"] - - -class JsonString(UserString): - REF_REGEXP = r"(?P<abstract_perfix>\.\.\/abstract/)(?P<kind_name>\w+)\.(?P<version>\d+\.\d+\.\d+)\.json" - NAMESPACE_REGEXP = r"\<namespace\>" - - def repl_closure(self, match: re.match): - if not match.groups: - print(self.data) - raise Exception - kind_name = match.group('kind_name') - version = match.group('version') - repl = f"{TENANT}:{AUTHORITY}:{kind_name}:{version}" - return repl - - def replace_refs(self): - self.data = re.sub(self.REF_REGEXP, self.repl_closure, self.data) - return self - - def replace_namespaces(self): - self.data = re.sub(self.NAMESPACE_REGEXP, TENANT, self.data) - return self - - @staticmethod - def lower_first_letter(val: str): - if val[0].islower(): - pass - elif val in ( - "ACL", - "Legals", - "ID" - ): - val = val.lower() - else: - val = val.replace(val[0], val[0].lower(), 1) - return val - - def to_pascal_case(self): - tmp_properties = {} - tmp_required = [] - json_file_dict = json.loads(self.data) - try: - if "schemaInfo" in json_file_dict: # if schema has additional fields to be recorded - content = json_file_dict["schema"] - else: - content = json_file_dict - if "properties" in content: - for key, value in content["properties"].items(): - tmp_properties[self.lower_first_letter(key)] = value - content["properties"] = tmp_properties - if "required" in content: - for i in content["required"]: - tmp_required.append(self.lower_first_letter(i)) - content["required"] = tmp_required - self.data = json.dumps(json_file_dict, indent=4) - return self - except Exception as e: - print(self.data) - raise e - - -for file_path in glob.glob(SCHEMAS_DIR + "/*.json"): - try: - with open(file_path, "r") as file: - content = file.read() - content = JsonString(content).replace_refs().replace_namespaces().to_pascal_case().data - with open(file_path, "w") as file: - file.write(content) - except Exception as e: - print(f"Error on file {file_path}") - raise e diff --git a/src/dags/libs/__init__.py b/src/dags/libs/__init__.py new file mode 100644 index 0000000..5511adb --- /dev/null +++ b/src/dags/libs/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2020 Google LLC +# Copyright 2020 EPAM Systems +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/src/dags/libs/create_records.py b/src/dags/libs/create_records.py index b8446cd..76ae3c8 100644 --- a/src/dags/libs/create_records.py +++ b/src/dags/libs/create_records.py @@ -17,7 +17,7 @@ import configparser import logging from airflow.models import Variable -from libs.refresh_token import refresh_token +from libs.refresh_token import AirflowTokenRefresher, refresh_token from osdu_api.model.acl import Acl from osdu_api.model.legal import Legal from osdu_api.model.legal_compliance import LegalCompliance @@ -38,7 +38,7 @@ DEFAULT_SOURCE = config.get("DEFAULTS", "authority") DEFAULT_VERSION = config.get("DEFAULTS", "kind_version") -@refresh_token +@refresh_token(AirflowTokenRefresher()) def create_update_record_request(headers, record_client, record): resp = record_client.create_update_records([record], headers.items()) return resp diff --git a/src/dags/libs/exceptions.py b/src/dags/libs/exceptions.py index 4397814..c2649b2 100644 --- a/src/dags/libs/exceptions.py +++ b/src/dags/libs/exceptions.py @@ -46,3 +46,15 @@ class GetSchemaError(Exception): Raise when can't find schema. """ pass + +class NotOSDUShemaFormatError(Exception): + """ + Raise when schema doesn't correspond OSDU format + """ + pass + +class SAFilePathError(Exception): + """ + Raise when sa_file path is not specified in Env Variables. + """ + pass diff --git a/src/dags/libs/refresh_token.py b/src/dags/libs/refresh_token.py index 9341466..a02d6b8 100644 --- a/src/dags/libs/refresh_token.py +++ b/src/dags/libs/refresh_token.py @@ -13,21 +13,23 @@ # See the License for the specific language governing permissions and # limitations under the License. +import enum +import json import logging +import os import sys -import time -from functools import partial +from typing import Callable, Union +from abc import ABC, abstractmethod from http import HTTPStatus +from urllib.parse import urlparse import requests -from airflow.models import Variable from google.auth.transport.requests import Request +from google.cloud import storage from google.oauth2 import service_account -from libs.exceptions import RefreshSATokenError +from libs.exceptions import RefreshSATokenError, SAFilePathError from tenacity import retry, stop_after_attempt -ACCESS_TOKEN = None - # Set up base logger handler = logging.StreamHandler(sys.stdout) handler.setFormatter( @@ -38,58 +40,136 @@ logger.addHandler(handler) RETRIES = 3 -SA_FILE_PATH = Variable.get("sa-file-osdu") -ACCESS_SCOPES = ['openid', 'email', 'profile'] +class TokenRefresher(ABC): -@retry(stop=stop_after_attempt(RETRIES)) -def get_access_token(sa_file: str, scopes: list) -> str: - """ - Refresh access token. - """ - try: - credentials = service_account.Credentials.from_service_account_file( - sa_file, scopes=scopes) - except ValueError as e: - logger.error("SA file has bad format.") - raise e - logger.info("Refresh token.") - credentials.refresh(Request()) - token = credentials.token - if credentials.token is None: - logger.error("Can't refresh token using SA-file. Token is empty.") - raise RefreshSATokenError - logger.info("Token refreshed.") - return token - - -@retry(stop=stop_after_attempt(RETRIES)) -def set_access_token(sa_file: str, scopes: list) -> str: - """ - Create token - """ - global ACCESS_TOKEN - token = get_access_token(sa_file, scopes) - auth = f"Bearer {token}" - ACCESS_TOKEN = token - Variable.set("access_token", token) - return auth + @abstractmethod + def refresh_token(self) -> str: + """ + Implement logics of refreshing token here. + """ + pass + + @property + @abstractmethod + def access_token(self) -> str: + pass + + @property + @abstractmethod + def authorization_header(self) -> dict: + """ + Must return authorization header for updating headers dict. + E.g. return {"Authorization": "Bearer <access_token>"} + """ + pass + + +class AirflowTokenRefresher(TokenRefresher): + DEFAULT_ACCESS_SCOPES = ['openid', 'email', 'profile'] + + def __init__(self, access_scopes: list=None): + from airflow.models import Variable + self.airflow_variables = Variable + self._access_token = None + self._access_scopes = access_scopes + @property + def access_scopes(self) -> list: + """ + Return access scopes. + Use DEFAULT_ACCESS_SCOPES if user-defined ones weren't provided. + """ + if not self._access_scopes: + self._access_scopes = self.DEFAULT_ACCESS_SCOPES + return self._access_scopes -def _check_token(): - global ACCESS_TOKEN - try: - if not ACCESS_TOKEN: - ACCESS_TOKEN = Variable.get('access_token') - except KeyError: - set_access_token(SA_FILE_PATH, ACCESS_SCOPES) + @staticmethod + @retry(stop=stop_after_attempt(RETRIES)) + def get_sa_file_content_from_google_storage(bucket_name: str, source_blob_name: str) -> str: + """ + Get sa_file content from Google Storage. + """ + storage_client = storage.Client() + bucket = storage_client.bucket(bucket_name) + blob = bucket.blob(source_blob_name) + logger.info("Got SA_file.") + return blob.download_as_string() + def get_sa_file_info(self) -> dict: + """ + Get file path from SA_FILE_PATH environmental variable. + This path can be GCS object URI or local file path. + Return content of sa path as dict. + """ + sa_file_path = os.environ.get("SA_FILE_PATH", None) + parsed_path = urlparse(sa_file_path) + if parsed_path.scheme == "gs": + bucket_name = parsed_path.netloc + source_blob_name = parsed_path.path[1:] # delete the first slash + sa_file_content = self.get_sa_file_content_from_google_storage(bucket_name, + source_blob_name) + sa_file_info = json.loads(sa_file_content) + elif not parsed_path.scheme and os.path.isfile(parsed_path.path): + with open(parsed_path.path) as f: + sa_file_info = json.load(f) + else: + raise SAFilePathError + return sa_file_info -def make_callable_request(obj, request_function, headers, *args, **kwargs): + @retry(stop=stop_after_attempt(RETRIES)) + def get_access_token_using_sa_file(self) -> str: + """ + Get new access token using SA info. + """ + sa_file_content = self.get_sa_file_info() + try: + credentials = service_account.Credentials.from_service_account_info( + sa_file_content, scopes=self.access_scopes) + except ValueError as e: + logger.error("SA file has bad format.") + raise e + + logger.info("Refresh token.") + credentials.refresh(Request()) + token = credentials.token + + if credentials.token is None: + logger.error("Can't refresh token using SA-file. Token is empty.") + raise RefreshSATokenError + + logger.info("Token refreshed.") + return token + + @retry(stop=stop_after_attempt(RETRIES)) + def refresh_token(self) -> str: + """ + Refresh token. Store its value in Airflow Variable. + """ + token = self.get_access_token_using_sa_file() + self.airflow_variables.set("access_token", token) + self._access_token = token + return self._access_token + + @property + def access_token(self) -> str: + if not self._access_token: + try: + self._access_token = self.airflow_variables.get("access_token") + except KeyError: + self.refresh_token() + return self._access_token + + @property + def authorization_header(self) -> dict: + return {"Authorization": f"Bearer {self.access_token}"} + + +def make_callable_request(obj: Union[object, None], request_function: Callable, headers: dict, + *args, **kwargs) -> Callable: """ Create send_request_with_auth function. """ - headers["Authorization"] = f"Bearer {ACCESS_TOKEN}" if obj: # if wrapped function is an object's method callable_request = lambda: request_function(obj, headers, *args, **kwargs) else: @@ -97,37 +177,58 @@ def make_callable_request(obj, request_function, headers, *args, **kwargs): return callable_request -def _wrapper(*args, **kwargs): +def _validate_headers_type(headers: dict): + if not isinstance(headers, dict): + logger.error(f"Got headers {headers}") + raise TypeError + + +def _validate_response_type(response: requests.Response, request_function: Callable): + if not isinstance(response, requests.Response): + logger.error(f"Function or method {request_function}" + f" must return values of type 'requests.Response'. " + f"Got {type(response)} instead") + raise TypeError + + +def _validate_token_refresher_type(token_refresher: TokenRefresher): + if not isinstance(token_refresher, TokenRefresher): + raise TypeError( + f"Token refresher must be of type {TokenRefresher}. Got {type(token_refresher)}" + ) + + +def send_request_with_auth_header(token_refresher: TokenRefresher, *args, + **kwargs) -> requests.Response: """ - Generic decorator wrapper for checking token and refreshing it. + Send request with authorization token. If response status is in HTTPStatus.UNAUTHORIZED or + HTTPStatus.FORBIDDEN, then refresh token and send request once again. """ - _check_token() obj = kwargs.pop("obj") if kwargs.get("obj") else None - headers = kwargs.pop("headers") request_function = kwargs.pop("request_function") - if not isinstance(headers, dict): - logger.error("Got headers %s" % headers) - raise TypeError + headers = kwargs.pop("headers") + _validate_headers_type(headers) + headers.update(token_refresher.authorization_header) + send_request_with_auth = make_callable_request(obj, request_function, headers, *args, **kwargs) response = send_request_with_auth() - if not isinstance(response, requests.Response): - logger.error("Function %s must return values of type requests.Response. " - "Got %s instead" % (request_function, type(response))) - raise TypeError + _validate_response_type(response, request_function) + if not response.ok: if response.status_code in (HTTPStatus.UNAUTHORIZED, HTTPStatus.FORBIDDEN): - set_access_token(SA_FILE_PATH, ACCESS_SCOPES) - send_request_with_auth = make_callable_request(obj, - request_function, - headers, - *args, **kwargs) + token_refresher.refresh_token() + headers.update(token_refresher.authorization_header) + send_request_with_auth = make_callable_request(obj, + request_function, + headers, + *args, **kwargs) response = send_request_with_auth() response.raise_for_status() return response -def refresh_token(request_function): +def refresh_token(token_refresher: TokenRefresher) -> Callable: """ Wrap a request function and check response. If response's error status code is about Authorization, refresh token and invoke this function once again. @@ -137,12 +238,24 @@ def refresh_token(request_function): Or method: request_method(self, header: dict, *args, **kwargs) -> requests.Response """ - is_method = len(request_function.__qualname__.split(".")) > 1 - if is_method: - def wrapper(obj, headers, *args, **kwargs): - return _wrapper(request_function=request_function, obj=obj, headers=headers, *args, - **kwargs) - else: - def wrapper(headers, *args, **kwargs): - return _wrapper(request_function=request_function, headers=headers, *args, **kwargs) - return wrapper + _validate_token_refresher_type(token_refresher) + + def refresh_token_wrapper(request_function: Callable) -> Callable: + is_method = len(request_function.__qualname__.split(".")) > 1 + if is_method: + def _wrapper(obj: object, headers: dict, *args, **kwargs) -> requests.Response: + return send_request_with_auth_header(token_refresher, + request_function=request_function, + obj=obj, + headers=headers, + *args, + **kwargs) + else: + def _wrapper(headers: dict, *args, **kwargs) -> requests.Response: + return send_request_with_auth_header(token_refresher, + request_function=request_function, + headers=headers, + *args, **kwargs) + return _wrapper + + return refresh_token_wrapper diff --git a/src/dags/osdu-ingest-r3.py b/src/dags/osdu-ingest-r3.py index e0292ca..674150c 100644 --- a/src/dags/osdu-ingest-r3.py +++ b/src/dags/osdu-ingest-r3.py @@ -55,12 +55,4 @@ process_manifest_op = ProcessManifestOperatorR3( dag=dag ) -search_record_ids_op = SearchRecordIdOperator( - task_id="search_record_ids_task", - provide_context=True, - dag=dag, - retries=4 -) - -update_status_running_op >> process_manifest_op >> \ -search_record_ids_op >> update_status_finished_op # pylint: disable=pointless-statement +update_status_running_op >> process_manifest_op >> update_status_finished_op # pylint: disable=pointless-statement diff --git a/src/plugins/operators/process_manifest_r2_op.py b/src/plugins/operators/process_manifest_r2_op.py index 89ddc08..e50ca8f 100644 --- a/src/plugins/operators/process_manifest_r2_op.py +++ b/src/plugins/operators/process_manifest_r2_op.py @@ -28,7 +28,7 @@ from urllib.error import HTTPError import requests from airflow.models import BaseOperator, Variable -from libs.refresh_token import refresh_token +from libs.refresh_token import AirflowTokenRefresher, refresh_token ACL_DICT = eval(Variable.get("acl")) LEGAL_DICT = eval(Variable.get("legal")) @@ -271,7 +271,7 @@ def create_workproduct_request_data(loaded_conf: dict, product_type: str, wp, wp return data_objects_list -@refresh_token +@refresh_token(AirflowTokenRefresher()) def send_request(headers, request_data): """ Send request to records storage API diff --git a/src/plugins/operators/process_manifest_r3.py b/src/plugins/operators/process_manifest_r3.py index b0323cd..f4fb2f4 100644 --- a/src/plugins/operators/process_manifest_r3.py +++ b/src/plugins/operators/process_manifest_r3.py @@ -26,7 +26,7 @@ import requests import tenacity from airflow.models import BaseOperator, Variable from libs.exceptions import EmptyManifestError, GetSchemaError -from libs.refresh_token import refresh_token +from libs.refresh_token import AirflowTokenRefresher, refresh_token # Set up base logger handler = logging.StreamHandler(sys.stdout) @@ -104,7 +104,7 @@ class SchemaValidator(object): 'AppKey': self.context.app_key, } - @refresh_token + @refresh_token(AirflowTokenRefresher()) def _get_schema_request(self, headers, uri): response = requests.get(uri, headers=headers, timeout=60) return response @@ -176,7 +176,7 @@ class ManifestProcessor(object): @tenacity.retry(tenacity.wait_fixed(TIMEOUT), tenacity.stop_after_attempt(RETRIES)) - @refresh_token + @refresh_token(AirflowTokenRefresher()) def save_record(self, headers, request_data): """ Send request to record storage API diff --git a/src/plugins/operators/search_record_id_op.py b/src/plugins/operators/search_record_id_op.py index 700afc7..ef8faed 100644 --- a/src/plugins/operators/search_record_id_op.py +++ b/src/plugins/operators/search_record_id_op.py @@ -24,7 +24,7 @@ from airflow.models import BaseOperator, Variable from airflow.utils.decorators import apply_defaults from hooks import search_http_hook, workflow_hook from libs.exceptions import RecordsNotSearchableError -from libs.refresh_token import refresh_token +from libs.refresh_token import AirflowTokenRefresher, refresh_token # Set up base logger handler = logging.StreamHandler(sys.stdout) @@ -99,7 +99,7 @@ class SearchRecordIdOperator(BaseOperator): data = resp.json() return data.get("totalCount") == self.expected_total_count - @refresh_token + @refresh_token(AirflowTokenRefresher()) def search_files(self, headers, **kwargs): if self.request_body: response = self.search_hook.run( diff --git a/src/plugins/operators/update_status_op.py b/src/plugins/operators/update_status_op.py index b02f678..145911b 100644 --- a/src/plugins/operators/update_status_op.py +++ b/src/plugins/operators/update_status_op.py @@ -25,7 +25,7 @@ from airflow.models import BaseOperator, Variable from airflow.utils.decorators import apply_defaults from hooks.http_hooks import search_http_hook, workflow_hook from libs.exceptions import PipelineFailedError -from libs.refresh_token import refresh_token +from libs.refresh_token import AirflowTokenRefresher, refresh_token # Set up base logger handler = logging.StreamHandler(sys.stdout) @@ -141,7 +141,7 @@ class UpdateStatusOperator(BaseOperator): if self.status == self.FAILED_STATUS: raise PipelineFailedError("Dag failed") - @refresh_token + @refresh_token(AirflowTokenRefresher()) def update_status_request(self, headers, status, **kwargs): data_conf = kwargs['dag_run'].conf logger.info(f"Got dataconf {data_conf}") diff --git a/tests/plugin-unit-tests/test_refresh_token.py b/tests/plugin-unit-tests/test_refresh_token.py new file mode 100644 index 0000000..b5e976a --- /dev/null +++ b/tests/plugin-unit-tests/test_refresh_token.py @@ -0,0 +1,50 @@ +import os +import sys +from unittest.mock import MagicMock + +import pytest + +sys.path.append(f"{os.getenv('AIRFLOW_SRC_DIR')}/dags") + +from libs.refresh_token import AirflowTokenRefresher +from libs.exceptions import SAFilePathError + + +def create_token_refresher() -> AirflowTokenRefresher: + token_refresher = AirflowTokenRefresher() + return token_refresher + + +@pytest.mark.parametrize( + "access_token", + [ + "test" + ] +) +def test_access_token_cached(access_token: str): + """ + Check if access token stored in Airflow Variables after refreshing it. + """ + token_refresher = create_token_refresher() + token_refresher.get_access_token_using_sa_file = MagicMock(return_value=access_token) + token_refresher.refresh_token() + assert token_refresher.airflow_variables.get("access_token") == access_token + + +@pytest.mark.parametrize( + "access_token", + [ + "test" + ] +) +def test_authorization_header(access_token: str): + token_refresher = create_token_refresher() + token_refresher.get_access_token_using_sa_file = MagicMock(return_value=access_token) + token_refresher.refresh_token() + assert token_refresher.authorization_header.get("Authorization") == f"Bearer {access_token}" + + +def test_raise_sa_path_error_on_getting_absent_sa_file(): + token_refresher = create_token_refresher() + with pytest.raises(SAFilePathError): + token_refresher.get_sa_file_info() diff --git a/tests/set_airflow_env.sh b/tests/set_airflow_env.sh index 399e1a1..ea7a432 100755 --- a/tests/set_airflow_env.sh +++ b/tests/set_airflow_env.sh @@ -16,6 +16,8 @@ pip install --upgrade google-api-python-client pip install dataclasses pip install jsonschema +pip install google +pip install google-cloud-storage export ACL='{"viewers": ["foo"],"owners": ["foo"]}' export LEGAL='{"legaltags": ["foo"], "otherRelevantDataCountries": ["FR", "US", "CA"],"status": "compliant"}' export WORKFLOW_URL="http://127.0.0.1:5000/wf" @@ -24,6 +26,7 @@ export LOCALHOST="http://127.0.0.1:5000" export SEARCH_CONN_ID="http://127.0.0.1:5000" export WORKFLOW_CONN_ID="http://127.0.0.1:5000" export DATALOAD_CONFIG_PATH="/usr/local/airflow/dags/configs/dataload.ini" +export SA_FILE_PATH="test" airflow initdb > /dev/null 2>&1 -- GitLab