Commit f0cd1baf authored by Siarhei Khaletski (EPAM)

Merge branch 'GONRG-790_Access_token_using_SA' into 'integration-master'

GONRG-790 "Access token using sa"

See merge request go3-nrg/platform/data-flow/ingestion/ingestion-dags!12
parents 3b8afceb 710b8152
Related merge request: !6 R3 Data Ingestion
@@ -17,6 +17,7 @@ import configparser
import logging
from airflow.models import Variable
+from libs.refresh_token import refresh_token
from osdu_api.model.acl import Acl
from osdu_api.model.legal import Legal
from osdu_api.model.legal_compliance import LegalCompliance
@@ -26,7 +27,6 @@ from osdu_api.storage.record_client import RecordClient
logger = logging.getLogger()
ACL_DICT = eval(Variable.get("acl"))
LEGAL_DICT = eval(Variable.get("legal"))
@@ -38,6 +38,12 @@ DEFAULT_SOURCE = config.get("DEFAULTS", "authority")
DEFAULT_VERSION = config.get("DEFAULTS", "kind_version")

+@refresh_token
+def create_update_record_request(headers, record_client, record):
+    resp = record_client.create_update_records([record], headers.items())
+    return resp

def create_records(**kwargs):
    # the only way to pass in values through the experimental api is through
    # the conf parameter
@@ -69,14 +75,13 @@ def create_records(**kwargs):
    headers = {
        "content-type": "application/json",
        "slb-data-partition-id": data_conf.get("partition-id", DEFAULT_SOURCE),
-        "Authorization": f"{auth}",
        "AppKey": data_conf.get("app-key", "")
    }

    record_client = RecordClient()
    record_client.data_partition_id = data_conf.get(
        "partition-id", DEFAULT_SOURCE)
-    resp = record_client.create_update_records([record], headers.items())
+    resp = create_update_record_request(headers, record_client, record)
    logger.info(f"Response: {resp.text}")
    kwargs["ti"].xcom_push(key="record_ids", value=resp.json()["recordIds"])
    return {"response_status": resp.status_code}
@@ -58,7 +58,8 @@ process_manifest_op = ProcessManifestOperator(
search_record_ids_op = SearchRecordIdOperator(
    task_id="search_record_ids_task",
    provide_context=True,
-    dag=dag
+    dag=dag,
+    retries=1
)

update_status_running_op >> process_manifest_op >> \
...
# Copyright 2020 Google LLC
# Copyright 2020 EPAM Systems
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

class RecordsNotSearchableError(Exception):
    """
    Raise when the expected totalCount of records differs from the actual one.
    """
    pass


class RefreshSATokenError(Exception):
    """
    Raise when the token is empty after an attempt to get credentials from the service account file.
    """
    pass


class PipelineFailedError(Exception):
    """
    Raise when the pipeline failed.
    """
    pass
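
For illustration only, a minimal sketch of how a caller might distinguish these exceptions; run_ingestion and notify_operators below are hypothetical names, not part of this merge request (the real raise sites are in the modules that follow):

from libs.exceptions import (PipelineFailedError, RecordsNotSearchableError,
                             RefreshSATokenError)

try:
    run_ingestion()  # hypothetical pipeline entry point
except RefreshSATokenError:
    # The service-account file yielded no token; credentials need attention.
    notify_operators("token refresh failed")
except RecordsNotSearchableError:
    # Search returned a totalCount that differs from the records just stored.
    notify_operators("stored records are not searchable yet")
except PipelineFailedError:
    # A task reported FAILED status for the workflow.
    notify_operators("ingestion pipeline failed")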
# Copyright 2020 Google LLC
# Copyright 2020 EPAM Systems
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import sys
from http import HTTPStatus
import requests
from airflow.models import Variable
from google.auth.transport.requests import Request
from google.oauth2 import service_account
from libs.exceptions import RefreshSATokenError
from tenacity import retry, stop_after_attempt
ACCESS_TOKEN = None
# Set up base logger
handler = logging.StreamHandler(sys.stdout)
handler.setFormatter(
    logging.Formatter("%(asctime)s [%(name)-14.14s] [%(levelname)-7.7s] %(message)s"))
logger = logging.getLogger("Dataload")
logger.setLevel(logging.INFO)
logger.addHandler(handler)
RETRIES = 3
SA_FILE_PATH = Variable.get("sa-file-osdu")
ACCESS_SCOPES = ['openid', 'email', 'profile']


@retry(stop=stop_after_attempt(RETRIES))
def get_access_token(sa_file: str, scopes: list) -> str:
    """
    Refresh access token.
    """
    try:
        credentials = service_account.Credentials.from_service_account_file(
            sa_file, scopes=scopes)
    except ValueError as e:
        logger.error("SA file has bad format.")
        raise e
    logger.info("Refresh token.")
    credentials.refresh(Request())
    token = credentials.token
    if credentials.token is None:
        logger.error("Can't refresh token using SA-file. Token is empty.")
        raise RefreshSATokenError
    logger.info("Token refreshed.")
    return token


@retry(stop=stop_after_attempt(RETRIES))
def set_access_token(sa_file: str, scopes: list) -> str:
    """
    Create token.
    """
    global ACCESS_TOKEN
    token = get_access_token(sa_file, scopes)
    auth = f"Bearer {token}"
    ACCESS_TOKEN = token
    Variable.set("access_token", token)
    return auth


def _check_token():
    global ACCESS_TOKEN
    try:
        if not ACCESS_TOKEN:
            ACCESS_TOKEN = Variable.get('access_token')
    except KeyError:
        set_access_token(SA_FILE_PATH, ACCESS_SCOPES)


def _wrapper(*args, **kwargs):
    """
    Generic decorator wrapper for checking the token and refreshing it.
    """
    _check_token()
    obj = kwargs.pop("obj") if kwargs.get("obj") else None
    headers = kwargs.pop("headers")
    request_function = kwargs.pop("request_function")
    if not isinstance(headers, dict):
        logger.error("Got headers %s" % headers)
        raise TypeError
    headers["Authorization"] = f"Bearer {ACCESS_TOKEN}"
    if obj:  # if the wrapped function is an object's method
        send_request_with_auth = lambda: request_function(obj, headers, *args, **kwargs)
    else:
        send_request_with_auth = lambda: request_function(headers, *args, **kwargs)
    response = send_request_with_auth()
    if not isinstance(response, requests.Response):
        logger.error("Function %s must return values of type requests.Response. "
                     "Got %s instead" % (request_function, type(response)))
        raise TypeError
    if not response.ok:
        if response.status_code in (HTTPStatus.UNAUTHORIZED, HTTPStatus.FORBIDDEN):
            set_access_token(SA_FILE_PATH, ACCESS_SCOPES)
            response = send_request_with_auth()
        else:
            response.raise_for_status()
    return response


def refresh_token(request_function):
    """
    Wrap a request function and check its response. If the response's error status code
    relates to authorization, refresh the token and invoke the function once again.
    If the response is not ok and not authorization-related, raise HTTPError.
    Expects a function:
        request_func(headers: dict, *args, **kwargs) -> requests.Response
    or a method:
        request_method(self, headers: dict, *args, **kwargs) -> requests.Response
    """
    is_method = len(request_function.__qualname__.split(".")) > 1
    if is_method:
        def wrapper(obj, headers, *args, **kwargs):
            return _wrapper(request_function=request_function, obj=obj, headers=headers, *args,
                            **kwargs)
    else:
        def wrapper(headers, *args, **kwargs):
            return _wrapper(request_function=request_function, headers=headers, *args, **kwargs)
    return wrapper
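
For orientation, a minimal usage sketch of the decorator; the endpoint and function name below are hypothetical and not part of this change, and the module assumes the Airflow Variables it reads (sa-file-osdu, access_token) are configured. The wrapped callable must take a headers dict as its first argument and return a requests.Response; the decorator injects the Authorization header and, on a 401/403, refreshes the service-account token and retries the call once.

import requests

from libs.refresh_token import refresh_token


@refresh_token
def get_record(headers: dict, record_id: str) -> requests.Response:
    # "Authorization" is set by the decorator before this function is called.
    return requests.get(f"https://storage.example.com/api/storage/v2/records/{record_id}",
                        headers=headers)


# No Authorization header is supplied by the caller; the decorator adds it.
resp = get_record({"Content-Type": "application/json"}, "opendes:doc:123")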
@@ -28,6 +28,7 @@ from urllib.error import HTTPError
import requests
from airflow.models import BaseOperator, Variable
+from libs.refresh_token import refresh_token

ACL_DICT = eval(Variable.get("acl"))
LEGAL_DICT = eval(Variable.get("legal"))
@@ -44,7 +45,8 @@ TIMEOUT = 1
# Set up base logger
handler = logging.StreamHandler(sys.stdout)
-handler.setFormatter(logging.Formatter("%(asctime)s [%(name)-14.14s] [%(levelname)-7.7s] %(message)s"))
+handler.setFormatter(
+    logging.Formatter("%(asctime)s [%(name)-14.14s] [%(levelname)-7.7s] %(message)s"))
logger = logging.getLogger("Dataload")
logger.setLevel(logging.INFO)
logger.addHandler(handler)
@@ -69,6 +71,7 @@ class FileType(enum.Enum):
    MANIFEST = enum.auto()
    WORKPRODUCT = enum.auto()


def dataload(**kwargs):
    data_conf = kwargs['dag_run'].conf
    loaded_conf = {
@@ -80,6 +83,17 @@ def dataload(**kwargs):
    return loaded_conf, conf_payload

+def create_headers(conf_payload):
+    partition_id = conf_payload["data-partition-id"]
+    app_key = conf_payload["AppKey"]
+    headers = {
+        'Content-type': 'application/json',
+        'data-partition-id': partition_id,
+        'AppKey': app_key
+    }
+    return headers


def generate_id(type_id):
    """
    Generate resource ID
@@ -100,6 +114,7 @@ def determine_data_type(raw_resource_type_id):
    return raw_resource_type_id.split("/")[-1].replace(":", "") \
        if raw_resource_type_id is not None else None


# TODO: add comments to functions that implement actions in this function
def process_file_items(loaded_conf, conf_payload) -> Tuple[list, list]:
    file_ids = []
@@ -118,6 +133,7 @@ def process_file_items(loaded_conf, conf_payload) -> Tuple[list, list]:
    )
    return file_list, file_ids


def process_wpc_items(loaded_conf, product_type, file_ids, conf_payload):
    wpc_ids = []
    wpc_list = []
@@ -153,6 +169,7 @@ def process_wp_item(loaded_conf, product_type, wpc_ids, conf_payload) -> list:
    ]
    return work_product


def validate_file_type(file_type, data_object):
    if not file_type:
        logger.error(f"Error with file {data_object}. Type could not be specified.")
@@ -172,8 +189,8 @@ def validate_file(loaded_conf) -> Tuple[FileType, str]:
    product_type = determine_data_type(data_object["WorkProduct"].get("ResourceTypeID"))
    validate_file_type(product_type, data_object)
    if product_type.lower() == "workproduct" and \
            data_object.get("WorkProductComponents") and \
            len(data_object["WorkProductComponents"]) >= 1:
        product_type = determine_data_type(
            data_object["WorkProductComponents"][0].get("ResourceTypeID"))
        validate_file_type(product_type, data_object)
@@ -183,6 +200,7 @@ def validate_file(loaded_conf) -> Tuple[FileType, str]:
        f"Error with file {data_object}. It doesn't have either Manifest or WorkProduct or ResourceTypeID.")
    sys.exit(2)


def create_kind(data_kind, conf_payload):
    partition_id = conf_payload.get("data-partition-id", DEFAULT_TENANT)
    source = conf_payload.get("authority", DEFAULT_SOURCE)
@@ -241,7 +259,9 @@ def create_manifest_request_data(loaded_conf: dict, product_type: str):
    legal_tag = loaded_conf.get("legal_tag")
    data_object = loaded_conf.get("data_object")
    data_objects_list = [
-        (populate_request_body(data_object["Manifest"], acl, legal_tag, product_type), product_type)]
+        (
+            populate_request_body(data_object["Manifest"], acl, legal_tag, product_type),
+            product_type)]
    return data_objects_list
@@ -251,19 +271,11 @@ def create_workproduct_request_data(loaded_conf: dict, product_type: str, wp, wp
    return data_objects_list


-def send_request(request_data, conf_payload):
+@refresh_token
+def send_request(headers, request_data):
    """
    Send request to records storage API
    """
-    auth = conf_payload["authorization"]
-    partition_id = conf_payload["data-partition-id"]
-    app_key = conf_payload["AppKey"]
-    headers = {
-        'Content-type': 'application/json',
-        'data-partition-id': partition_id,
-        'Authorization': auth,
-        'AppKey': app_key
-    }
    logger.error(f"Header {str(headers)}")
    # loop for implementing retries send process
@@ -277,7 +289,7 @@ def send_request(request_data, conf_payload):
        if response.status_code in DATA_LOAD_OK_RESPONSE_CODES:
            file_logger.info(",".join(map(str, response.json()["recordIds"])))
-            break
+            return response

        reason = response.text[:250]
        logger.error(f"Request error.")
@@ -287,6 +299,7 @@ def send_request(request_data, conf_payload):
        if retry + 1 < retries:
            if response.status_code in BAD_TOKEN_RESPONSE_CODES:
                logger.error("Invalid or expired token.")
+                return response
            else:
                time_to_sleep = TIMEOUT
@@ -303,23 +316,25 @@ def send_request(request_data, conf_payload):
    logger.error(f"Request could not be completed.\n"
                 f"Reason: {reason}")
    sys.exit(2)
-    return response.json()["recordIds"]


def process_manifest(**kwargs):
    loaded_conf, conf_payload = dataload(**kwargs)
    file_type, product_type = validate_file(loaded_conf)
    if file_type is FileType.MANIFEST:
-        request = create_manifest_request_data(loaded_conf, product_type)
+        manifest_record = create_manifest_request_data(loaded_conf, product_type)
    elif file_type is FileType.WORKPRODUCT:
        file_list, file_ids = process_file_items(loaded_conf, conf_payload)
        kwargs["ti"].xcom_push(key="file_ids", value=file_ids)
        wpc_list, wpc_ids = process_wpc_items(loaded_conf, product_type, file_ids, conf_payload)
        wp_list = process_wp_item(loaded_conf, product_type, wpc_ids, conf_payload)
-        request = create_workproduct_request_data(loaded_conf, product_type, wp_list, wpc_list, file_list)
+        manifest_record = create_workproduct_request_data(loaded_conf, product_type, wp_list,
+                                                          wpc_list,
+                                                          file_list)
    else:
        sys.exit(2)
-    record_ids = send_request(request, conf_payload)
+    headers = create_headers(conf_payload)
+    record_ids = send_request(headers, request_data=manifest_record).json()["recordIds"]
    kwargs["ti"].xcom_push(key="record_ids", value=record_ids)
...
@@ -14,28 +14,27 @@
# limitations under the License.

-import enum
import json
import logging
import sys
-from functools import partial
from typing import Tuple

import tenacity
from airflow.models import BaseOperator, Variable
from airflow.utils.decorators import apply_defaults
from hooks import search_http_hook, workflow_hook
+from libs.exceptions import RecordsNotSearchableError
+from libs.refresh_token import refresh_token

# Set up base logger
handler = logging.StreamHandler(sys.stdout)
-handler.setFormatter(logging.Formatter("%(asctime)s [%(name)-14.14s] [%(levelname)-7.7s] %(message)s"))
+handler.setFormatter(
+    logging.Formatter("%(asctime)s [%(name)-14.14s] [%(levelname)-7.7s] %(message)s"))
logger = logging.getLogger("Dataload")
logger.setLevel(logging.INFO)
logger.addHandler(handler)


class SearchRecordIdOperator(BaseOperator):
    """
    Operator to search files in SearchService by record ids.
@@ -49,32 +48,25 @@ class SearchRecordIdOperator(BaseOperator):
    FAILED_STATUS = "failed"

    @apply_defaults
-    def __init__( self, *args, **kwargs) -> None:
+    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.workflow_hook = workflow_hook
        self.search_hook = search_http_hook
+        # these will be set at the beginning of the execute method
+        self.request_body = None
+        self.expected_total_count = None

-    @staticmethod
-    def _file_searched(resp, expected_total_count) -> bool:
-        """Check if search service returns expected totalCount.
-        The method is used as a callback
-        """
-        data = resp.json()
-        return data.get("totalCount") == expected_total_count

    def get_headers(self, **kwargs) -> dict:
        data_conf = kwargs['dag_run'].conf
        # for /submitWithManifest authorization and partition-id are inside Payload field
        if "Payload" in data_conf:
-            auth = data_conf["Payload"]["authorization"]
            partition_id = data_conf["Payload"]["data-partition-id"]
        else:
-            auth = data_conf["authorization"]
            partition_id = data_conf["data-partition-id"]
        headers = {
            'Content-type': 'application/json',
            'data-partition-id': partition_id,
-            'Authorization': auth,
+            'Authorization': "",
        }
        return headers
@@ -86,35 +78,51 @@ class SearchRecordIdOperator(BaseOperator):
        query = f"id:({record_ids})"
        return query, expected_total_count

-    def search_files(self, **kwargs):
+    def _create_request_body(self, **kwargs):
        record_ids = kwargs["ti"].xcom_pull(key="record_ids")
        if record_ids:
            query, expected_total_count = self._create_search_query(record_ids)
        else:
            logger.error("There are no record ids")
            sys.exit(2)
-        headers = self.get_headers(**kwargs)
        request_body = {
            "kind": "*:*:*:*",
            "query": query
        }
-        retry_opts = {
-            "wait": tenacity.wait_exponential(multiplier=5),
-            "stop": tenacity.stop_after_attempt(5),
-            "retry": tenacity.retry_if_not_result(
-                partial(self._file_searched, expected_total_count=expected_total_count)
-            )
-        }
-        self.search_hook.run_with_advanced_retry(
-            endpoint=Variable.get("search_query_ep"),
-            headers=headers,
-            data=json.dumps(request_body),
-            _retry_args=retry_opts
-        )
+        return request_body, expected_total_count
+
+    def _is_record_searchable(self, resp) -> bool:
+        """
+        Check if search service returns expected totalCount of records.
+        """
+        data = resp.json()
+        return data.get("totalCount") == self.expected_total_count
+
+    @tenacity.retry(wait=tenacity.wait_exponential(multiplier=5),
+                    stop=tenacity.stop_after_attempt(5))
+    @refresh_token
+    def search_files(self, headers, **kwargs):
+        if self.request_body:
+            response = self.search_hook.run(
+                endpoint=Variable.get("search_query_ep"),
+                headers=headers,
+                data=json.dumps(self.request_body),
+                extra_options={"check_response": False}
+            )
+            if not self._is_record_searchable(response):
+                logger.error("Expected amount (%s) of records not found." %
+                             self.expected_total_count)
+                raise RecordsNotSearchableError
+            return response
+        else:
+            logger.error("There is an error in header or in request body")
+            sys.exit(2)

    def execute(self, context):
        """Execute update workflow status.
        If status assumed to be FINISHED then we check whether proceed files are searchable or not.
        If they are then update status FINISHED else FAILED
        """
-        self.search_files(**context)
+        self.request_body, self.expected_total_count = self._create_request_body(**context)
+        headers = self.get_headers(**context)
+        self.search_files(headers, **context)
@@ -23,20 +23,20 @@ from functools import partial
import tenacity
from airflow.models import BaseOperator, Variable
from airflow.utils.decorators import apply_defaults
from hooks import search_http_hook, workflow_hook
+from libs.exceptions import PipelineFailedError
+from libs.refresh_token import refresh_token

# Set up base logger
handler = logging.StreamHandler(sys.stdout)
-handler.setFormatter(logging.Formatter("%(asctime)s [%(name)-14.14s] [%(levelname)-7.7s] %(message)s"))
+handler.setFormatter(
+    logging.Formatter("%(asctime)s [%(name)-14.14s] [%(levelname)-7.7s] %(message)s"))
logger = logging.getLogger("Dataload")
logger.setLevel(logging.INFO)
logger.addHandler(handler)


class UpdateStatusOperator(BaseOperator):
    ui_color = '#10ECAA'
    ui_fgcolor = '#000000'
@@ -50,7 +50,7 @@ class UpdateStatusOperator(BaseOperator):
    FAILED = enum.auto()

    @apply_defaults
-    def __init__( self, *args, **kwargs) -> None:
+    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.workflow_hook = workflow_hook
        self.search_hook = search_http_hook
@@ -67,15 +67,12 @@ class UpdateStatusOperator(BaseOperator):
        data_conf = kwargs['dag_run'].conf
        # for /submitWithManifest authorization and partition-id are inside Payload field
        if "Payload" in data_conf:
-            auth = data_conf["Payload"]["authorization"]
            partition_id = data_conf["Payload"]["data-partition-id"]
        else:
-            auth = data_conf["authorization"]
            partition_id = data_conf["data-partition-id"]
        headers = {
            'Content-type': 'application/json',
            'data-partition-id': partition_id,
-            'Authorization': auth,
        }
        return headers
@@ -114,8 +111,9 @@ class UpdateStatusOperator(BaseOperator):
    def previous_ti_statuses(self, context):
        dagrun = context['ti'].get_dagrun()
-        failed_ti, success_ti = dagrun.get_task_instances(state='failed'), dagrun.get_task_instances(state='success')
-        if not failed_ti and not success_ti: # There is no prev task so it can't have been failed
+        failed_ti, success_ti = dagrun.get_task_instances(
+            state='failed'), dagrun.get_task_instances(state='success')
+        if not failed_ti and not success_ti:  # There is no prev task so it can't have been failed
            logger.info("There are no tasks before this one. So it has status RUNNING")
            return self.prev_ti_state.NONE
        if failed_ti:
@@ -138,22 +136,25 @@ class UpdateStatusOperator(BaseOperator):
        If status assumed to be FINISHED then we check whether proceed files are searchable or not.
        If they are then update status FINISHED else FAILED
        """
-        self.update_status_rqst(self.status, **context)
+        headers = self.get_headers(**context)
+        self.update_status_request(headers, self.status, **context)
        if self.status == self.FAILED_STATUS:
-            raise Exception("Dag failed")
+            raise PipelineFailedError("Dag failed")

-    def update_status_rqst(self, status, **kwargs):
+    @refresh_token
+    def update_status_request(self, headers, status, **kwargs):
        data_conf = kwargs['dag_run'].conf
        logger.info(f"Got dataconf {data_conf}")
        workflow_id = data_conf["WorkflowID"]
-        headers = self.get_headers(**kwargs)
        request_body = {
            "WorkflowID": workflow_id,
            "Status": status
        }
        logger.info(f" Sending request '{status}'")
-        self.workflow_hook.run(
+        response = self.workflow_hook.run(
            endpoint=Variable.get("update_status_ep"),
            data=json.dumps(request_body),
-            headers=headers
+            headers=headers,
+            extra_options={"check_response": False}
        )
+        return response
@@ -14,8 +14,7 @@
# limitations under the License.

-from flask import Flask, url_for, request
-from flask import json
+from flask import Flask, json, request, url_for

OSDU_INGEST_SUCCES_FIFO = "/tmp/osdu_ingest_success"
OSDU_INGEST_FAILED_FIFO = "/tmp/osdu_ingest_failed"
...
@@ -13,4 +13,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from .process_manifest_op import *
\ No newline at end of file
+from .process_manifest_op import *
@@ -14,45 +14,47 @@
# limitations under the License.

-import re
import os
-import pytest
+import re
import sys

+import pytest
+
sys.path.append(f"{os.getenv('AIRFLOW_SRC_DIR')}/plugins")

-from operators import process_manifest_op as p_m_op
from data import process_manifest_op as test_data
+from operators import process_manifest_op


@pytest.mark.parametrize(
    "test_input, expected",
    [
        ("srn:type:work-product/WellLog:", "WellLog"),
        ("srn:type:file/las2:", "las2"),
    ]
)
def test_determine_data_type(test_input, expected):
-    data_type = p_m_op.determine_data_type(test_input)
+    data_type = process_manifest_op.determine_data_type(test_input)
    assert data_type == expected


@pytest.mark.parametrize(
    "data_type, loaded_conf, conf_payload, expected_file_result",
    [
        ("las2",
         test_data.LOADED_CONF,
         test_data.CONF_PAYLOAD,
         test_data.PROCESS_FILE_ITEMS_RESULT)
    ]
)
def test_process_file_items(data_type, loaded_conf, conf_payload, expected_file_result):
    file_id_regex = re.compile(r"srn\:file/" + data_type + r"\:\d+\:")
    expected_file_list = expected_file_result[0]
-    file_list, file_ids = p_m_op.process_file_items(loaded_conf, conf_payload)
+    file_list, file_ids = process_manifest_op.process_file_items(loaded_conf, conf_payload)
    for i in file_ids:
        assert file_id_regex.match(i)
    for i in file_list:
        assert file_id_regex.match(i[0]["data"]["ResourceID"])
        i[0]["data"]["ResourceID"] = ""
    assert file_list == expected_file_list
@@ -51,6 +51,9 @@ airflow variables -s update_status_ep wf/us
airflow variables -s search_url $LOCALHOST
airflow variables -s dataload_config_path $DATALOAD_CONFIG_PATH
airflow variables -s search_query_ep sr/qr
+airflow variables -s access_token test
+airflow variables -s "sa-file-osdu" "test"
airflow connections -a --conn_id search --conn_uri $SEARCH_CONN_ID
airflow connections -a --conn_id workflow --conn_uri $WORKFLOW_CONN_ID
airflow connections -a --conn_id google_cloud_storage --conn_uri $WORKFLOW_CONN_ID
...
@@ -17,28 +17,28 @@ import enum
import subprocess
import time


class DagStatus(enum.Enum):
-    RUNNING = enum.auto()
-    FAILED = enum.auto()
-    FINISHED = enum.auto()
+    RUNNING = "running"
+    FAILED = "failed"
+    FINISHED = "finished"


OSDU_INGEST_SUCCESS_SH = "/mock-server/./test-osdu-ingest-success.sh"
OSDU_INGEST_FAIL_SH = "/mock-server/./test-osdu-ingest-fail.sh"
DEFAULT_INGEST_SUCCESS_SH = "/mock-server/./test-default-ingest-success.sh"
DEFAULT_INGEST_FAIL_SH = "/mock-server/./test-default-ingest-fail.sh"

-with open("/tmp/osdu_ingest_result", "w") as f:
-    f.close()

subprocess.run(f"/bin/bash -c 'airflow scheduler > /dev/null 2>&1 &'", shell=True)


-def check_dag_status(dag_name):
+def check_dag_status(dag_name: str) -> DagStatus:
    time.sleep(5)
    output = subprocess.getoutput(f'airflow list_dag_runs {dag_name}')
    if "failed" in output:
        print(dag_name)
        print(output)
        return DagStatus.FAILED
    if "running" in output:
        return DagStatus.RUNNING
    print(dag_name)
@@ -46,32 +46,17 @@ def check_dag_status(dag_name):
    return DagStatus.FINISHED


-def test_dag_success(dag_name, script):
-    print(f"Test {dag_name} success")
-    subprocess.run(f"{script}", shell=True)
-    while True:
-        dag_status = check_dag_status(dag_name)
-        if dag_status is DagStatus.RUNNING:
-            continue
-        elif dag_status is DagStatus.FINISHED:
-            return
-        else:
-            raise Exception(f"Error {dag_name} supposed to be finished")
-
-
-def test_dag_fail(dag_name, script):
+def test_dag_execution_result(dag_name: str, script: str, expected_status: DagStatus):
    subprocess.run(f"{script}", shell=True)
-    print(f"Expecting {dag_name} fail")
+    print(f"Expecting {dag_name} to be {expected_status.value}")
    while True:
        dag_status = check_dag_status(dag_name)
-        if dag_status is DagStatus.RUNNING:
-            continue
-        elif dag_status is DagStatus.FAILED:
-            return
-        else:
-            raise Exception(f"Error {dag_name} supposed to be failed")
+        if dag_status is not DagStatus.RUNNING:
+            break
+    assert dag_status is expected_status, f"Error {dag_name} supposed to be {expected_status.value}"


-test_dag_success("Osdu_ingest", OSDU_INGEST_SUCCESS_SH)
-test_dag_fail("Osdu_ingest", OSDU_INGEST_FAIL_SH)
-test_dag_success("Default_ingest", DEFAULT_INGEST_SUCCESS_SH)
-test_dag_fail("Default_ingest", DEFAULT_INGEST_FAIL_SH)
+test_dag_execution_result("Osdu_ingest", OSDU_INGEST_SUCCESS_SH, DagStatus.FINISHED)
+test_dag_execution_result("Osdu_ingest", OSDU_INGEST_FAIL_SH, DagStatus.FAILED)
+test_dag_execution_result("Default_ingest", DEFAULT_INGEST_SUCCESS_SH, DagStatus.FINISHED)
+test_dag_execution_result("Default_ingest", DEFAULT_INGEST_FAIL_SH, DagStatus.FAILED)
+pip uninstall enum34 -y
pip install pytest
pip install --upgrade google-api-python-client
chmod +x tests/set_airflow_env.sh
...