Commit 543b4cca authored by Siarhei Khaletski (EPAM)

Merge branch 'GONRG-1180_SA_file_for_access_token' into 'integration-master'

GONRG-1180: Add refresh token using bucket

See merge request go3-nrg/platform/data-flow/ingestion/ingestion-dags!16
parents f2a67962 43c5ff47
# Copyright 2020 Google LLC
# Copyright 2020 EPAM Systems
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
......@@ -17,7 +17,7 @@ import configparser
import logging
from airflow.models import Variable
from libs.refresh_token import refresh_token
from libs.refresh_token import AirflowTokenRefresher, refresh_token
from osdu_api.model.acl import Acl
from osdu_api.model.legal import Legal
from osdu_api.model.legal_compliance import LegalCompliance
......@@ -38,7 +38,7 @@ DEFAULT_SOURCE = config.get("DEFAULTS", "authority")
DEFAULT_VERSION = config.get("DEFAULTS", "kind_version")
@refresh_token
@refresh_token(AirflowTokenRefresher())
def create_update_record_request(headers, record_client, record):
resp = record_client.create_update_records([record], headers.items())
return resp
......
......@@ -46,3 +46,15 @@ class GetSchemaError(Exception):
Raise when the schema can't be found.
"""
pass
class NotOSDUShemaFormatError(Exception):
"""
Raise when the schema doesn't correspond to the OSDU format.
"""
pass
class SAFilePathError(Exception):
"""
Raise when the sa_file path is not specified in environment variables.
"""
pass
......@@ -13,21 +13,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import enum
import json
import logging
import os
import sys
import time
from functools import partial
from typing import Callable, Union
from abc import ABC, abstractmethod
from http import HTTPStatus
from urllib.parse import urlparse
import requests
from airflow.models import Variable
from google.auth.transport.requests import Request
from google.cloud import storage
from google.oauth2 import service_account
from libs.exceptions import RefreshSATokenError
from libs.exceptions import RefreshSATokenError, SAFilePathError
from tenacity import retry, stop_after_attempt
ACCESS_TOKEN = None
# Set up base logger
handler = logging.StreamHandler(sys.stdout)
handler.setFormatter(
......@@ -38,58 +40,136 @@ logger.addHandler(handler)
RETRIES = 3
SA_FILE_PATH = Variable.get("sa-file-osdu")
ACCESS_SCOPES = ['openid', 'email', 'profile']
class TokenRefresher(ABC):
@retry(stop=stop_after_attempt(RETRIES))
def get_access_token(sa_file: str, scopes: list) -> str:
"""
Refresh access token.
"""
try:
credentials = service_account.Credentials.from_service_account_file(
sa_file, scopes=scopes)
except ValueError as e:
logger.error("SA file has bad format.")
raise e
logger.info("Refresh token.")
credentials.refresh(Request())
token = credentials.token
if credentials.token is None:
logger.error("Can't refresh token using SA-file. Token is empty.")
raise RefreshSATokenError
logger.info("Token refreshed.")
return token
@retry(stop=stop_after_attempt(RETRIES))
def set_access_token(sa_file: str, scopes: list) -> str:
"""
Create token
"""
global ACCESS_TOKEN
token = get_access_token(sa_file, scopes)
auth = f"Bearer {token}"
ACCESS_TOKEN = token
Variable.set("access_token", token)
return auth
@abstractmethod
def refresh_token(self) -> str:
"""
Implement the logic of refreshing the token here.
"""
pass
@property
@abstractmethod
def access_token(self) -> str:
pass
@property
@abstractmethod
def authorization_header(self) -> dict:
"""
Must return authorization header for updating headers dict.
E.g. return {"Authorization": "Bearer <access_token>"}
"""
pass
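# Illustrative only, not part of this change: a minimal sketch of an alternative
# TokenRefresher implementation. Any refresher only has to provide the three members
# declared above; EnvTokenRefresher and STATIC_ACCESS_TOKEN are hypothetical names.
class EnvTokenRefresher(TokenRefresher):
    """Read a pre-issued token from an environment variable (e.g. for local runs)."""

    def __init__(self):
        self._access_token = None

    def refresh_token(self) -> str:
        self._access_token = os.environ["STATIC_ACCESS_TOKEN"]
        return self._access_token

    @property
    def access_token(self) -> str:
        if not self._access_token:
            self.refresh_token()
        return self._access_token

    @property
    def authorization_header(self) -> dict:
        return {"Authorization": f"Bearer {self.access_token}"}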
class AirflowTokenRefresher(TokenRefresher):
DEFAULT_ACCESS_SCOPES = ['openid', 'email', 'profile']
def __init__(self, access_scopes: list=None):
from airflow.models import Variable
self.airflow_variables = Variable
self._access_token = None
self._access_scopes = access_scopes
@property
def access_scopes(self) -> list:
"""
Return access scopes.
Use DEFAULT_ACCESS_SCOPES if user-defined ones weren't provided.
"""
if not self._access_scopes:
self._access_scopes = self.DEFAULT_ACCESS_SCOPES
return self._access_scopes
def _check_token():
global ACCESS_TOKEN
try:
if not ACCESS_TOKEN:
ACCESS_TOKEN = Variable.get('access_token')
except KeyError:
set_access_token(SA_FILE_PATH, ACCESS_SCOPES)
@staticmethod
@retry(stop=stop_after_attempt(RETRIES))
def get_sa_file_content_from_google_storage(bucket_name: str, source_blob_name: str) -> str:
"""
Get sa_file content from Google Storage.
"""
storage_client = storage.Client()
bucket = storage_client.bucket(bucket_name)
blob = bucket.blob(source_blob_name)
logger.info("Got SA_file.")
return blob.download_as_string()
def get_sa_file_info(self) -> dict:
"""
Get the file path from the SA_FILE_PATH environment variable.
The path can be a GCS object URI or a local file path.
Return the content of the SA file as a dict.
"""
sa_file_path = os.environ.get("SA_FILE_PATH", None)
parsed_path = urlparse(sa_file_path)
if parsed_path.scheme == "gs":
bucket_name = parsed_path.netloc
source_blob_name = parsed_path.path[1:] # delete the first slash
sa_file_content = self.get_sa_file_content_from_google_storage(bucket_name,
source_blob_name)
sa_file_info = json.loads(sa_file_content)
elif not parsed_path.scheme and os.path.isfile(parsed_path.path):
with open(parsed_path.path) as f:
sa_file_info = json.load(f)
else:
raise SAFilePathError
return sa_file_info
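# Illustrative only, not part of this change: how the two supported SA_FILE_PATH forms
# are seen by urlparse in get_sa_file_info above; the bucket and file names are hypothetical.
_gcs = urlparse("gs://my-secrets-bucket/osdu/sa-file.json")
assert _gcs.scheme == "gs" and _gcs.netloc == "my-secrets-bucket"
assert _gcs.path[1:] == "osdu/sa-file.json"  # leading slash stripped, as above
_local = urlparse("/usr/local/airflow/keys/sa-file.json")
assert not _local.scheme and _local.path == "/usr/local/airflow/keys/sa-file.json"
# Any other form (unset variable, unsupported scheme, missing file) raises SAFilePathError.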
def make_callable_request(obj, request_function, headers, *args, **kwargs):
@retry(stop=stop_after_attempt(RETRIES))
def get_access_token_using_sa_file(self) -> str:
"""
Get new access token using SA info.
"""
sa_file_content = self.get_sa_file_info()
try:
credentials = service_account.Credentials.from_service_account_info(
sa_file_content, scopes=self.access_scopes)
except ValueError as e:
logger.error("SA file has bad format.")
raise e
logger.info("Refresh token.")
credentials.refresh(Request())
token = credentials.token
if credentials.token is None:
logger.error("Can't refresh token using SA-file. Token is empty.")
raise RefreshSATokenError
logger.info("Token refreshed.")
return token
@retry(stop=stop_after_attempt(RETRIES))
def refresh_token(self) -> str:
"""
Refresh token. Store its value in Airflow Variable.
"""
token = self.get_access_token_using_sa_file()
self.airflow_variables.set("access_token", token)
self._access_token = token
return self._access_token
@property
def access_token(self) -> str:
if not self._access_token:
try:
self._access_token = self.airflow_variables.get("access_token")
except KeyError:
self.refresh_token()
return self._access_token
@property
def authorization_header(self) -> dict:
return {"Authorization": f"Bearer {self.access_token}"}
def make_callable_request(obj: Union[object, None], request_function: Callable, headers: dict,
*args, **kwargs) -> Callable:
"""
Create send_request_with_auth function.
"""
headers["Authorization"] = f"Bearer {ACCESS_TOKEN}"
if obj: # if wrapped function is an object's method
callable_request = lambda: request_function(obj, headers, *args, **kwargs)
else:
......@@ -97,37 +177,58 @@ def make_callable_request(obj, request_function, headers, *args, **kwargs):
return callable_request
def _wrapper(*args, **kwargs):
def _validate_headers_type(headers: dict):
if not isinstance(headers, dict):
logger.error(f"Got headers {headers}")
raise TypeError
def _validate_response_type(response: requests.Response, request_function: Callable):
if not isinstance(response, requests.Response):
logger.error(f"Function or method {request_function}"
f" must return values of type 'requests.Response'. "
f"Got {type(response)} instead")
raise TypeError
def _validate_token_refresher_type(token_refresher: TokenRefresher):
if not isinstance(token_refresher, TokenRefresher):
raise TypeError(
f"Token refresher must be of type {TokenRefresher}. Got {type(token_refresher)}"
)
def send_request_with_auth_header(token_refresher: TokenRefresher, *args,
**kwargs) -> requests.Response:
"""
Generic decorator wrapper for checking token and refreshing it.
Send request with authorization token. If response status is in HTTPStatus.UNAUTHORIZED or
HTTPStatus.FORBIDDEN, then refresh token and send request once again.
"""
_check_token()
obj = kwargs.pop("obj") if kwargs.get("obj") else None
headers = kwargs.pop("headers")
request_function = kwargs.pop("request_function")
if not isinstance(headers, dict):
logger.error("Got headers %s" % headers)
raise TypeError
headers = kwargs.pop("headers")
_validate_headers_type(headers)
headers.update(token_refresher.authorization_header)
send_request_with_auth = make_callable_request(obj, request_function, headers,
*args, **kwargs)
response = send_request_with_auth()
if not isinstance(response, requests.Response):
logger.error("Function %s must return values of type requests.Response. "
"Got %s instead" % (request_function, type(response)))
raise TypeError
_validate_response_type(response, request_function)
if not response.ok:
if response.status_code in (HTTPStatus.UNAUTHORIZED, HTTPStatus.FORBIDDEN):
set_access_token(SA_FILE_PATH, ACCESS_SCOPES)
send_request_with_auth = make_callable_request(obj,
request_function,
headers,
*args, **kwargs)
token_refresher.refresh_token()
headers.update(token_refresher.authorization_header)
send_request_with_auth = make_callable_request(obj,
request_function,
headers,
*args, **kwargs)
response = send_request_with_auth()
response.raise_for_status()
return response
def refresh_token(request_function):
def refresh_token(token_refresher: TokenRefresher) -> Callable:
"""
Wrap a request function and check the response. If the response's status code
indicates an authorization error, refresh the token and invoke the function once again.
......@@ -137,12 +238,24 @@ def refresh_token(request_function):
Or method:
request_method(self, header: dict, *args, **kwargs) -> requests.Response
"""
is_method = len(request_function.__qualname__.split(".")) > 1
if is_method:
def wrapper(obj, headers, *args, **kwargs):
return _wrapper(request_function=request_function, obj=obj, headers=headers, *args,
**kwargs)
else:
def wrapper(headers, *args, **kwargs):
return _wrapper(request_function=request_function, headers=headers, *args, **kwargs)
return wrapper
_validate_token_refresher_type(token_refresher)
def refresh_token_wrapper(request_function: Callable) -> Callable:
is_method = len(request_function.__qualname__.split(".")) > 1
if is_method:
def _wrapper(obj: object, headers: dict, *args, **kwargs) -> requests.Response:
return send_request_with_auth_header(token_refresher,
request_function=request_function,
obj=obj,
headers=headers,
*args,
**kwargs)
else:
def _wrapper(headers: dict, *args, **kwargs) -> requests.Response:
return send_request_with_auth_header(token_refresher,
request_function=request_function,
headers=headers,
*args, **kwargs)
return _wrapper
return refresh_token_wrapper
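# Illustrative only, not part of this change: the two callable shapes that
# refresh_token(...) accepts, matching the docstring above. STORAGE_URL, put_record
# and StorageClient are hypothetical names.
STORAGE_URL = "https://osdu.example.com/api/storage/v2/records"

@refresh_token(AirflowTokenRefresher())
def put_record(headers: dict, record: dict) -> requests.Response:
    # plain function: the headers dict must be the first positional argument
    return requests.put(STORAGE_URL, headers=headers, json=[record])

class StorageClient:
    @refresh_token(AirflowTokenRefresher())
    def put_record(self, headers: dict, record: dict) -> requests.Response:
        # bound method: headers come right after self
        return requests.put(STORAGE_URL, headers=headers, json=[record])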
......@@ -55,12 +55,4 @@ process_manifest_op = ProcessManifestOperatorR3(
dag=dag
)
search_record_ids_op = SearchRecordIdOperator(
task_id="search_record_ids_task",
provide_context=True,
dag=dag,
retries=4
)
update_status_running_op >> process_manifest_op >> \
search_record_ids_op >> update_status_finished_op # pylint: disable=pointless-statement
update_status_running_op >> process_manifest_op >> update_status_finished_op # pylint: disable=pointless-statement
......@@ -28,7 +28,7 @@ from urllib.error import HTTPError
import requests
from airflow.models import BaseOperator, Variable
from libs.refresh_token import refresh_token
from libs.refresh_token import AirflowTokenRefresher, refresh_token
ACL_DICT = eval(Variable.get("acl"))
LEGAL_DICT = eval(Variable.get("legal"))
......@@ -271,7 +271,7 @@ def create_workproduct_request_data(loaded_conf: dict, product_type: str, wp, wp
return data_objects_list
@refresh_token
@refresh_token(AirflowTokenRefresher())
def send_request(headers, request_data):
"""
Send request to records storage API
......
......@@ -26,7 +26,7 @@ import requests
import tenacity
from airflow.models import BaseOperator, Variable
from libs.exceptions import EmptyManifestError, GetSchemaError
from libs.refresh_token import refresh_token
from libs.refresh_token import AirflowTokenRefresher, refresh_token
# Set up base logger
handler = logging.StreamHandler(sys.stdout)
......@@ -104,7 +104,7 @@ class SchemaValidator(object):
'AppKey': self.context.app_key,
}
@refresh_token
@refresh_token(AirflowTokenRefresher())
def _get_schema_request(self, headers, uri):
response = requests.get(uri, headers=headers, timeout=60)
return response
......@@ -176,7 +176,7 @@ class ManifestProcessor(object):
@tenacity.retry(tenacity.wait_fixed(TIMEOUT),
tenacity.stop_after_attempt(RETRIES))
@refresh_token
@refresh_token(AirflowTokenRefresher())
def save_record(self, headers, request_data):
"""
Send request to record storage API
......
......@@ -24,7 +24,7 @@ from airflow.models import BaseOperator, Variable
from airflow.utils.decorators import apply_defaults
from hooks import search_http_hook, workflow_hook
from libs.exceptions import RecordsNotSearchableError
from libs.refresh_token import refresh_token
from libs.refresh_token import AirflowTokenRefresher, refresh_token
# Set up base logger
handler = logging.StreamHandler(sys.stdout)
......@@ -99,7 +99,7 @@ class SearchRecordIdOperator(BaseOperator):
data = resp.json()
return data.get("totalCount") == self.expected_total_count
@refresh_token
@refresh_token(AirflowTokenRefresher())
def search_files(self, headers, **kwargs):
if self.request_body:
response = self.search_hook.run(
......
......@@ -25,7 +25,7 @@ from airflow.models import BaseOperator, Variable
from airflow.utils.decorators import apply_defaults
from hooks.http_hooks import search_http_hook, workflow_hook
from libs.exceptions import PipelineFailedError
from libs.refresh_token import refresh_token
from libs.refresh_token import AirflowTokenRefresher, refresh_token
# Set up base logger
handler = logging.StreamHandler(sys.stdout)
......@@ -141,7 +141,7 @@ class UpdateStatusOperator(BaseOperator):
if self.status == self.FAILED_STATUS:
raise PipelineFailedError("Dag failed")
@refresh_token
@refresh_token(AirflowTokenRefresher())
def update_status_request(self, headers, status, **kwargs):
data_conf = kwargs['dag_run'].conf
logger.info(f"Got dataconf {data_conf}")
......
import os
import sys
from unittest.mock import MagicMock
import pytest
sys.path.append(f"{os.getenv('AIRFLOW_SRC_DIR')}/dags")
from libs.refresh_token import AirflowTokenRefresher
from libs.exceptions import SAFilePathError
def create_token_refresher() -> AirflowTokenRefresher:
token_refresher = AirflowTokenRefresher()
return token_refresher
@pytest.mark.parametrize(
"access_token",
[
"test"
]
)
def test_access_token_cached(access_token: str):
"""
Check that the access token is stored in Airflow Variables after it is refreshed.
"""
token_refresher = create_token_refresher()
token_refresher.get_access_token_using_sa_file = MagicMock(return_value=access_token)
token_refresher.refresh_token()
assert token_refresher.airflow_variables.get("access_token") == access_token
@pytest.mark.parametrize(
"access_token",
[
"test"
]
)
def test_authorization_header(access_token: str):
token_refresher = create_token_refresher()
token_refresher.get_access_token_using_sa_file = MagicMock(return_value=access_token)
token_refresher.refresh_token()
assert token_refresher.authorization_header.get("Authorization") == f"Bearer {access_token}"
def test_raise_sa_path_error_on_getting_absent_sa_file():
token_refresher = create_token_refresher()
with pytest.raises(SAFilePathError):
token_refresher.get_sa_file_info()
......@@ -16,6 +16,8 @@
pip install --upgrade google-api-python-client
pip install dataclasses
pip install jsonschema
pip install google
pip install google-cloud-storage
export ACL='{"viewers": ["foo"],"owners": ["foo"]}'
export LEGAL='{"legaltags": ["foo"], "otherRelevantDataCountries": ["FR", "US", "CA"],"status": "compliant"}'
export WORKFLOW_URL="http://127.0.0.1:5000/wf"
......@@ -24,6 +26,7 @@ export LOCALHOST="http://127.0.0.1:5000"
export SEARCH_CONN_ID="http://127.0.0.1:5000"
export WORKFLOW_CONN_ID="http://127.0.0.1:5000"
export DATALOAD_CONFIG_PATH="/usr/local/airflow/dags/configs/dataload.ini"
export SA_FILE_PATH="test"
airflow initdb > /dev/null 2>&1
......