diff --git a/src/dags/libs/exceptions.py b/src/dags/libs/exceptions.py
index 747ea51b3b2297d3b1b77be85ff539477d7db984..839562f97723e06a29cd1ac67a465c1233f90f6d 100644
--- a/src/dags/libs/exceptions.py
+++ b/src/dags/libs/exceptions.py
@@ -59,3 +59,25 @@ class SAFilePathError(Exception):
     """
     pass
 
+class FileSourceError(Exception):
+    """
+    Raise when a file doesn't exist under the given path.
+    """
+    pass
+
+class GCSObjectURIError(Exception):
+    """
+    Raise when a malformed Google Cloud Storage object URI was given.
+    """
+    pass
+
+class UploadFileError(Exception):
+    """
+    Raise when there is an error while uploading a file into OSDU.
+    """
+
+class TokenRefresherNotPresentError(Exception):
+    """
+    Raise when a token refresher is not present in the 'refresh_token' decorator.
+    """
+    pass
diff --git a/src/dags/libs/process_manifest_r3.py b/src/dags/libs/process_manifest_r3.py
index 4b2f55b060724b42d4c250268a37bcd909837325..bc24dcaab367ea0e6bce22efa1fd9fde823641a9 100644
--- a/src/dags/libs/process_manifest_r3.py
+++ b/src/dags/libs/process_manifest_r3.py
@@ -23,8 +23,10 @@ import requests
 import tenacity
 from libs.context import Context
 from libs.exceptions import EmptyManifestError, NotOSDUShemaFormatError
+from libs.source_file_check import SourceFileChecker
+from libs.upload_file import FileUploader
 from libs.mixins import HeadersMixin
-from libs.refresh_token import AirflowTokenRefresher, refresh_token
+from libs.refresh_token import AirflowTokenRefresher, refresh_token, TokenRefresher
 
 logger = logging.getLogger()
 
@@ -43,11 +45,20 @@ class ManifestProcessor(HeadersMixin):
         }
     }
 
-    def __init__(self, storage_url: str, dagrun_conf: dict, context: Context):
+    def __init__(
+        self,
+        storage_url: str,
+        dagrun_conf: dict,
+        file_uploader: FileUploader,
+        token_refresher: TokenRefresher,
+        context: Context
+    ):
         super().__init__(context)
+        self.file_uploader = file_uploader
         self.storage_url = storage_url
         self.data_object = copy.deepcopy(dagrun_conf)
         self.context = context
+        self.token_refresher = token_refresher
 
     @staticmethod
     def _get_kind_name(kind: str) -> str:
@@ -57,6 +68,11 @@ class ManifestProcessor(HeadersMixin):
         kind_name = kind.split(":")[2]
         return kind_name
 
+    def upload_source_file(self, file_record: dict) -> dict:
+        file_path = file_record["data"]["PreLoadFilePath"]
+        file_record["data"]["FileSource"] = self.file_uploader.upload_file(file_path)
+        return file_record
+
     def generate_id(self, manifest_fragment: dict) -> str:
         """
         Generate id to use it in Storage.
@@ -79,7 +95,14 @@ class ManifestProcessor(HeadersMixin):
         record["data"] = manifest
         return record
 
-    def _validate_storage_response(self, response_dict: dict):
+    @staticmethod
+    def _validate_file_source_checker_type(file_source_checker: SourceFileChecker):
+        if not isinstance(file_source_checker, SourceFileChecker):
+            raise TypeError(f"File checker must be an instance of '{SourceFileChecker}' class.\n"
+                            f"Got instance of '{file_source_checker}'")
+
+    @staticmethod
+    def _validate_storage_response(response_dict: dict):
         if not (
             isinstance(response_dict, dict)
             and isinstance(response_dict.get("recordIds"), list)
@@ -97,14 +120,15 @@ class ManifestProcessor(HeadersMixin):
         Send request to record storage API.
         """
         request_data = json.dumps(request_data)
-        logger.info("Send to Storage service")
-        logger.info(f"{request_data}")
+        logger.info("Sending records to Storage service")
+        logger.debug(f"{request_data}")
         response = requests.put(self.storage_url, request_data, headers=headers)
         if response.ok:
             response_dict = response.json()
             self._validate_storage_response(response_dict)
-            logger.info(f"Response: {response_dict}")
-            logger.info(",".join(map(str, response_dict["recordIds"])))
+            record_ids = ", ".join(map(str, response_dict["recordIds"]))
+            logger.debug(f"Response: {response_dict}")
+            logger.info(f"Records '{record_ids}' were saved using Storage service.")
         else:
             reason = response.text[:250]
             logger.error(f"Request error.")
@@ -136,7 +160,8 @@ class ManifestProcessor(HeadersMixin):
         """
         records = []
         for file in manifest["Files"]:
-            record = self.populate_manifest_storage_record(file)
+            record = self.upload_source_file(file)
+            record = self.populate_manifest_storage_record(record)
             records.append(record)
         return records
 
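Note on the ManifestProcessor change above: the Files loop now uploads each source file before building its storage record, so "FileSource" is populated from the File service response instead of staying empty. A minimal sketch of that transformation, with a stubbed uploader (the stub class and the gs:// paths are illustrative only, not part of the change):

```python
class StubFileUploader:
    """Illustrative stand-in; the real implementation is GCSFileUploader below."""

    def upload_file(self, preload_path: str) -> str:
        # Pretend the File service returned this landing location.
        return "gs://file-service-bucket/some-generated-location"


file_record = {"data": {"PreLoadFilePath": "gs://landing-zone/raw.segy", "FileSource": ""}}

uploader = StubFileUploader()
# Mirrors ManifestProcessor.upload_source_file: read PreLoadFilePath, write FileSource.
file_record["data"]["FileSource"] = uploader.upload_file(file_record["data"]["PreLoadFilePath"])

assert file_record["data"]["FileSource"] != ""
```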
""" request_data = json.dumps(request_data) - logger.info("Send to Storage service") - logger.info(f"{request_data}") + logger.info("Sending records to Storage service") + logger.debug(f"{request_data}") response = requests.put(self.storage_url, request_data, headers=headers) if response.ok: response_dict = response.json() self._validate_storage_response(response_dict) - logger.info(f"Response: {response_dict}") - logger.info(",".join(map(str, response_dict["recordIds"]))) + record_ids = ", ".join(map(str, response_dict["recordIds"])) + logger.debug(f"Response: {response_dict}") + logger.info(f"Records '{record_ids}' were saved using Storage service.") else: reason = response.text[:250] logger.error(f"Request error.") @@ -136,7 +160,8 @@ class ManifestProcessor(HeadersMixin): """ records = [] for file in manifest["Files"]: - record = self.populate_manifest_storage_record(file) + record = self.upload_source_file(file) + record = self.populate_manifest_storage_record(record) records.append(record) return records diff --git a/src/dags/libs/refresh_token.py b/src/dags/libs/refresh_token.py index 6228694a4a84de3072358d9eb1b3bf2e31b9360c..dc0a2eba93e7730f2490967c821f2edf9e13f7f8 100644 --- a/src/dags/libs/refresh_token.py +++ b/src/dags/libs/refresh_token.py @@ -26,7 +26,7 @@ import requests from google.auth.transport.requests import Request from google.cloud import storage from google.oauth2 import service_account -from libs.exceptions import RefreshSATokenError, SAFilePathError +from libs.exceptions import RefreshSATokenError, SAFilePathError, TokenRefresherNotPresentError from tenacity import retry, stop_after_attempt logger = logging.getLogger() @@ -181,7 +181,7 @@ def make_callable_request(obj: Union[object, None], request_function: Callable, def _validate_headers_type(headers: dict): if not isinstance(headers, dict): logger.error(f"Got headers {headers}") - raise TypeError + raise TypeError("Request's headers type expected to be 'dict'") def _validate_response_type(response: requests.Response, request_function: Callable): @@ -199,20 +199,38 @@ def _validate_token_refresher_type(token_refresher: TokenRefresher): ) +def _get_object_token_refresher( + token_refresher: TokenRefresher, + obj: object = None +) -> TokenRefresher: + """ + Check if token refresher passed into decorator or specified in object's as 'token_refresher' + property. + """ + if token_refresher is None and obj: + try: + obj.__getattribute__("token_refresher") + except AttributeError: + raise TokenRefresherNotPresentError("Token refresher must be passed into decorator or " + "set as object's 'refresh_token' attribute.") + else: + token_refresher = obj.token_refresher + return token_refresher + + def send_request_with_auth_header(token_refresher: TokenRefresher, *args, **kwargs) -> requests.Response: """ Send request with authorization token. If response status is in HTTPStatus.UNAUTHORIZED or HTTPStatus.FORBIDDEN, then refresh token and send request once again. 
""" - obj = kwargs.pop("obj") if kwargs.get("obj") else None + obj = kwargs.pop("obj", None) request_function = kwargs.pop("request_function") headers = kwargs.pop("headers") _validate_headers_type(headers) headers.update(token_refresher.authorization_header) - send_request_with_auth = make_callable_request(obj, request_function, headers, - *args, **kwargs) + send_request_with_auth = make_callable_request(obj, request_function, headers, *args, **kwargs) response = send_request_with_auth() _validate_response_type(response, request_function) @@ -233,7 +251,7 @@ def send_request_with_auth_header(token_refresher: TokenRefresher, *args, return response -def refresh_token(token_refresher: TokenRefresher) -> Callable: +def refresh_token(token_refresher: TokenRefresher = None) -> Callable: """ Wrap a request function and check response. If response's error status code is about Authorization, refresh token and invoke this function once again. @@ -244,13 +262,13 @@ def refresh_token(token_refresher: TokenRefresher) -> Callable: request_method(self, header: dict, *args, **kwargs) -> requests.Response """ - _validate_token_refresher_type(token_refresher) - def refresh_token_wrapper(request_function: Callable) -> Callable: is_method = len(request_function.__qualname__.split(".")) > 1 if is_method: def _wrapper(obj: object, headers: dict, *args, **kwargs) -> requests.Response: - return send_request_with_auth_header(token_refresher, + _token_refresher = _get_object_token_refresher(token_refresher, obj) + _validate_token_refresher_type(_token_refresher) + return send_request_with_auth_header(_token_refresher, request_function=request_function, obj=obj, headers=headers, @@ -258,6 +276,7 @@ def refresh_token(token_refresher: TokenRefresher) -> Callable: **kwargs) else: def _wrapper(headers: dict, *args, **kwargs) -> requests.Response: + _validate_token_refresher_type(token_refresher) return send_request_with_auth_header(token_refresher, request_function=request_function, headers=headers, diff --git a/src/dags/libs/search_record_ids.py b/src/dags/libs/search_record_ids.py index 73859305c24a42fbaff37e5740b92fa125bfcd82..e33a5ca1629bd9477769f2f2794e47f1914606ac 100644 --- a/src/dags/libs/search_record_ids.py +++ b/src/dags/libs/search_record_ids.py @@ -33,7 +33,7 @@ TIMEOUT = 10 class SearchId(HeadersMixin): - def __init__(self, search_url: str, record_ids: list, context: Context): + def __init__(self, search_url: str, record_ids: list, token_refresher, context: Context): super().__init__(context) if not record_ids: logger.error("There are no record ids") @@ -41,6 +41,7 @@ class SearchId(HeadersMixin): self.record_ids = record_ids self.search_url = search_url self.expected_total_count = len(record_ids) + self.token_refresher = token_refresher self._create_request_body() def _create_search_query(self) -> str: @@ -48,7 +49,7 @@ class SearchId(HeadersMixin): Create search query to send to Search service using recordIds need to be found. """ record_ids = " OR ".join(f"\"{id_}\"" for id_ in self.record_ids) - logger.info(f"Search query {record_ids}") + logger.debug(f"Search query {record_ids}") query = f"id:({record_ids})" return query @@ -67,10 +68,10 @@ class SearchId(HeadersMixin): """ Check if search service returns expected totalCount of records. 
""" - logger.info(response.text) + logger.debug(response.text) data = response.json() total_count = data.get('totalCount') - logger.info(f"Got total count {total_count}") + logger.debug(f"Got total count {total_count}") if total_count is None: raise ValueError(f"Got no totalCount field in Search service response. " f"Response is {data}.") @@ -81,7 +82,7 @@ class SearchId(HeadersMixin): stop=tenacity.stop_after_attempt(RETRIES), reraise=True ) - @refresh_token(AirflowTokenRefresher()) + @refresh_token() def search_files(self, headers: dict) -> requests.Response: """ Send request with recordIds to Search service. @@ -92,7 +93,10 @@ class SearchId(HeadersMixin): logger.error("Expected amount (%s) of records not found." % self.expected_total_count, ) - raise RecordsNotSearchableError + raise RecordsNotSearchableError( + f"Can't find records {self.request_body}. " + f"Got response {response.json()} from Search service." + ) return response def check_records_searchable(self): diff --git a/src/dags/libs/source_file_check.py b/src/dags/libs/source_file_check.py new file mode 100644 index 0000000000000000000000000000000000000000..6f5468fa42f9648968e14f5f7d9957cedff4a5a8 --- /dev/null +++ b/src/dags/libs/source_file_check.py @@ -0,0 +1,78 @@ +# Copyright 2020 Google LLC +# Copyright 2020 EPAM Systems +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from abc import ABC, abstractmethod +from typing import Tuple +from urllib.parse import urlparse + +import tenacity +from google.cloud import storage +from libs.exceptions import GCSObjectURIError, FileSourceError + + +RETRIES = 3 + + +class SourceFileChecker(ABC): + + @abstractmethod + def does_file_exist(self, file_path: str) -> bool: + """ + Validate if file exists under given file path. + """ + pass + + +class GCSSourceFileChecker(SourceFileChecker): + + def __init__(self): + pass + + def __repr__(self): + return "GCS file checker" + + def _parse_object_uri(self, file_path: str) -> Tuple[str, str]: + """ + Parse GCS Object uri. + Return bucket and blob names. + """ + parsed_path = urlparse(file_path) + if parsed_path.scheme == "gs": + bucket_name = parsed_path.netloc + source_blob_name = parsed_path.path[1:] # delete the first slash + + if bucket_name and source_blob_name: + return bucket_name, source_blob_name + + raise GCSObjectURIError(f"Wrong format path to GCS object. 
diff --git a/src/dags/libs/update_status.py b/src/dags/libs/update_status.py
index e00ea18e9af2774aceccac9cf1797cdfbc8f1ea8..5b756747437dc8219c29e6270aac8ba91e139fc4 100644
--- a/src/dags/libs/update_status.py
+++ b/src/dags/libs/update_status.py
@@ -20,7 +20,7 @@ import logging
 import requests
 from libs.context import Context
 from libs.mixins import HeadersMixin
-from libs.refresh_token import AirflowTokenRefresher, refresh_token
+from libs.refresh_token import TokenRefresher, refresh_token
 
 logger = logging.getLogger()
 
@@ -32,6 +32,7 @@ class UpdateStatus(HeadersMixin):
         workflow_id: str,
         workflow_url: str,
         status: str,
+        token_refresher: TokenRefresher,
         context: Context,
     ) -> None:
         super().__init__(context)
@@ -39,15 +40,16 @@ class UpdateStatus(HeadersMixin):
         self.workflow_id = workflow_id
         self.context = context
         self.status = status
+        self.token_refresher = token_refresher
 
-    @refresh_token(AirflowTokenRefresher())
+    @refresh_token()
     def update_status_request(self, headers: dict) -> requests.Response:
         request_body = {
             "WorkflowID": self.workflow_id,
             "Status": self.status
         }
         request_body = json.dumps(request_body)
-        logger.info(f" Sending request '{request_body}'")
+        logger.debug(f"Sending request '{request_body}'")
         response = requests.post(self.workflow_url, request_body, headers=headers)
         return response
 
diff --git a/src/dags/libs/upload_file.py b/src/dags/libs/upload_file.py
new file mode 100644
index 0000000000000000000000000000000000000000..73e260d2398b647465c95fbb81f831a835547a40
--- /dev/null
+++ b/src/dags/libs/upload_file.py
@@ -0,0 +1,157 @@
+# Copyright 2020 Google LLC
+# Copyright 2020 EPAM Systems
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import io
+import json
+import logging
+from abc import ABC, abstractmethod
+from typing import Tuple, TypeVar
+from urllib.parse import urlparse
+
+import requests
+import tenacity
+from google.cloud import storage
+from libs.context import Context
+from libs.exceptions import GCSObjectURIError, FileSourceError
+from libs.mixins import HeadersMixin
+from libs.refresh_token import TokenRefresher, refresh_token
+
+logger = logging.getLogger()
+
+FileLikeObject = TypeVar("FileLikeObject", io.IOBase, io.RawIOBase, io.BytesIO)
+
+RETRY_SETTINGS = {
+    "stop": tenacity.stop_after_attempt(3),
+    "wait": tenacity.wait_fixed(2),
+}
+
+
+class FileUploader(HeadersMixin, ABC):
+    """
+    File uploader to copy a file from PreLoadFilePath into FileSource on the OSDU platform.
+    """
+
+    def __init__(self, file_service: str, token_refresher: TokenRefresher, context: Context):
+        super().__init__(context)
+        self.file_service = file_service
+        self.token_refresher = token_refresher
+
+    @abstractmethod
+    def get_file_from_preload_path(self, preload_path: str) -> FileLikeObject:
+        """
+        Return a file-like object containing the raw content of a file
+        in the preload path.
+        """
+
+    @tenacity.retry(**RETRY_SETTINGS)
+    @refresh_token()
+    def _send_post_request(self, headers: dict, url: str, request_body: str) -> requests.Response:
+        response = requests.post(url, request_body, headers=headers)
+        return response
+
+    @tenacity.retry(**RETRY_SETTINGS)
+    def _get_signed_url_request(self, headers: dict) -> Tuple[str, str]:
+        """
+        Get FileID and SignedURL using the File service.
+        """
+        logger.debug("Getting signed url.")
+        request_body = json.dumps({})  # the '/getLocation' method requires an empty json.
+        response = self._send_post_request(headers, f"{self.file_service}/getLocation",
+                                           request_body).json()
+        logger.debug("Got signed url.")
+        logger.debug(response)
+        return response["FileID"], response["Location"]["SignedURL"]
+
+    @tenacity.retry(**RETRY_SETTINGS)
+    def _upload_file_request(self, headers: dict, signed_url: str, buffer: FileLikeObject):
+        """
+        Upload a file via the File service using signed_url.
+        """
+        logger.debug("Uploading file.")
+        buffer.seek(0)
+        requests.put(signed_url, buffer.read(), headers=headers)
+        logger.debug("File uploaded.")
+
+    @tenacity.retry(**RETRY_SETTINGS)
+    def _get_file_location_request(self, headers: dict, file_id: str) -> str:
+        """
+        Get the file location using the File service.
+        """
+        logger.debug("Getting file location.")
+        request_body = json.dumps({"FileID": file_id})
+        response = self._send_post_request(headers, f"{self.file_service}/getFileLocation",
+                                           request_body).json()
+        logger.debug("Got file location.")
+        return response["Location"]
+
+    def upload_file(self, preload_file_path: str) -> str:
+        """
+        Copy a file from the landing zone (preload_file_path) onto the OSDU platform
+        using the File service. Return file_location.
+        """
+        buffer = self.get_file_from_preload_path(preload_file_path)
+        file_id, signed_url = self._get_signed_url_request(self.request_headers)
+        self._upload_file_request(self.request_headers, signed_url, buffer)
+        file_location = self._get_file_location_request(self.request_headers, file_id)
+        return file_location
+ """ + parsed_path = urlparse(file_path) + if parsed_path.scheme == "gs": + bucket_name = parsed_path.netloc + source_blob_name = parsed_path.path[1:] # delete the first slash + + if bucket_name and source_blob_name: + return bucket_name, source_blob_name + + raise GCSObjectURIError + + @tenacity.retry(**RETRY_SETTINGS) + def get_file_from_bucket(self, bucket_name: str, source_blob_name: str) -> io.BytesIO: + storage_client = storage.Client() + bucket = storage_client.bucket(bucket_name) + blob = bucket.blob(source_blob_name) + + does_exist = blob.exists() + if not does_exist: + raise FileSourceError("File doesn't exist in preloadPath " + f"'gs://{bucket_name}/{source_blob_name}'") + + file = io.BytesIO() + blob.download_to_file(file) + logger.debug("File got from landing zone") + return file + + def get_file_from_preload_path(self, preload_file_path: str) -> io.BytesIO: + bucket_name, blob_name = self._parse_object_uri(preload_file_path) + buffer = self.get_file_from_bucket(bucket_name, blob_name) + return buffer diff --git a/src/dags/libs/validate_schema.py b/src/dags/libs/validate_schema.py index 99edb682c54cb295c776371dcf95f35868d7e043..9506e1281253828a1398c7feeac762f2e1d23200 100644 --- a/src/dags/libs/validate_schema.py +++ b/src/dags/libs/validate_schema.py @@ -22,7 +22,7 @@ import tenacity from libs.context import Context from libs.exceptions import EmptyManifestError, NotOSDUShemaFormatError from libs.mixins import HeadersMixin -from libs.refresh_token import AirflowTokenRefresher, refresh_token +from libs.refresh_token import TokenRefresher, refresh_token logger = logging.getLogger() @@ -55,11 +55,17 @@ class OSDURefResolver(jsonschema.RefResolver): class SchemaValidator(HeadersMixin): """Class to validate schema of Manifests.""" - def __init__(self, schema_service: str, dagrun_conf: dict, context: Context): + def __init__( + self, schema_service: str, + dagrun_conf: dict, + token_refresher: TokenRefresher, + context: Context + ): super().__init__(context) self.schema_service = schema_service self.data_object = copy.deepcopy(dagrun_conf) self.context = context + self.token_refresher = token_refresher self.resolver_handlers = { "osdu": self.get_schema_request, "https": self.get_schema_request, @@ -71,7 +77,7 @@ class SchemaValidator(HeadersMixin): stop=tenacity.stop_after_attempt(RETRIES), reraise=True ) - @refresh_token(AirflowTokenRefresher()) + @refresh_token() def _get_schema_from_schema_service(self, headers: dict, uri: str) -> requests.Response: """ Request to Schema service to retrieve schema. 
diff --git a/src/dags/libs/validate_schema.py b/src/dags/libs/validate_schema.py
index 99edb682c54cb295c776371dcf95f35868d7e043..9506e1281253828a1398c7feeac762f2e1d23200 100644
--- a/src/dags/libs/validate_schema.py
+++ b/src/dags/libs/validate_schema.py
@@ -22,7 +22,7 @@ import tenacity
 from libs.context import Context
 from libs.exceptions import EmptyManifestError, NotOSDUShemaFormatError
 from libs.mixins import HeadersMixin
-from libs.refresh_token import AirflowTokenRefresher, refresh_token
+from libs.refresh_token import TokenRefresher, refresh_token
 
 logger = logging.getLogger()
 
@@ -55,11 +55,17 @@ class OSDURefResolver(jsonschema.RefResolver):
 class SchemaValidator(HeadersMixin):
     """Class to validate schema of Manifests."""
 
-    def __init__(self, schema_service: str, dagrun_conf: dict, context: Context):
+    def __init__(
+        self, schema_service: str,
+        dagrun_conf: dict,
+        token_refresher: TokenRefresher,
+        context: Context
+    ):
         super().__init__(context)
         self.schema_service = schema_service
         self.data_object = copy.deepcopy(dagrun_conf)
         self.context = context
+        self.token_refresher = token_refresher
         self.resolver_handlers = {
             "osdu": self.get_schema_request,
             "https": self.get_schema_request,
@@ -71,7 +77,7 @@ class SchemaValidator(HeadersMixin):
         stop=tenacity.stop_after_attempt(RETRIES),
         reraise=True
     )
-    @refresh_token(AirflowTokenRefresher())
+    @refresh_token()
     def _get_schema_from_schema_service(self, headers: dict, uri: str) -> requests.Response:
         """
         Request to Schema service to retrieve schema.
@@ -105,7 +111,7 @@ class SchemaValidator(HeadersMixin):
         """
         if not schema:
             schema = self.get_schema(manifest["kind"])
-        logger.info(f"Validating kind {manifest['kind']}")
+        logger.debug(f"Validating kind {manifest['kind']}")
         resolver = OSDURefResolver(schema_service=self.schema_service,
                                    base_uri=schema.get("$id", ""), referrer=schema,
                                    handlers=self.resolver_handlers, cache_remote=True)
diff --git a/src/plugins/operators/process_manifest_r3.py b/src/plugins/operators/process_manifest_r3.py
index 6850e41b363dd0dd59ac41990b4a599fa02484ac..f2ceb67d527fe4c60764e39f090273837bb1f6ef 100644
--- a/src/plugins/operators/process_manifest_r3.py
+++ b/src/plugins/operators/process_manifest_r3.py
@@ -18,10 +18,11 @@ import logging
 from airflow.utils import apply_defaults
 from airflow.models import BaseOperator, Variable
 from libs.context import Context
+from libs.upload_file import GCSFileUploader
+from libs.refresh_token import AirflowTokenRefresher
 from libs.process_manifest_r3 import ManifestProcessor
 from libs.validate_schema import SchemaValidator
 
-
 logger = logging.getLogger()
 
 RETRIES = 3
@@ -37,19 +38,28 @@ class ProcessManifestOperatorR3(BaseOperator):
         super().__init__(*args, **kwargs)
         self.schema_service_url = Variable.get('schema_service_url')
         self.storage_url = Variable.get('storage_url')
+        self.file_service_url = Variable.get('file_service_url')
 
     def execute(self, context: dict):
         payload_context = Context.populate(context["dag_run"].conf)
+        token_refresher = AirflowTokenRefresher()
+        file_uploader = GCSFileUploader(self.file_service_url, token_refresher,
+                                        payload_context)
+
         validator = SchemaValidator(
             self.schema_service_url,
             context["dag_run"].conf,
+            token_refresher,
             payload_context
         )
         manifest_processor = ManifestProcessor(
             self.storage_url,
             context["dag_run"].conf,
-            payload_context
+            file_uploader,
+            token_refresher,
+            payload_context,
         )
+
         validator.validate_manifest()
         record_ids = manifest_processor.process_manifest()
         context["ti"].xcom_push(key="record_ids", value=record_ids)
diff --git a/src/plugins/operators/search_record_id.py b/src/plugins/operators/search_record_id.py
index eb5b36679f0985b2188a8fb4025085bf0751e5e2..d122dc6507e74320adfffeb928c486e481030020 100644
--- a/src/plugins/operators/search_record_id.py
+++ b/src/plugins/operators/search_record_id.py
@@ -17,6 +17,7 @@ import logging
 
 from airflow.models import BaseOperator, Variable
 from libs.context import Context
+from libs.refresh_token import AirflowTokenRefresher
 from libs.search_record_ids import SearchId
 
 logger = logging.getLogger()
@@ -41,5 +42,6 @@ class SearchRecordIdOperator(BaseOperator):
         """
         payload_context = Context.populate(context["dag_run"].conf)
         record_ids = context["ti"].xcom_pull(key="record_ids", )
-        ids_searcher = SearchId(Variable.get("search_url"), record_ids, payload_context, )
+        ids_searcher = SearchId(Variable.get("search_url"), record_ids, AirflowTokenRefresher(),
+                                payload_context)
         ids_searcher.check_records_searchable()
diff --git a/src/plugins/operators/update_status.py b/src/plugins/operators/update_status.py
index 06eee9d6ef3af244af463aa0995ed5de1d77afdf..c2b70bca91119b6901d15a6e09b223aa3e92f2f0 100644
--- a/src/plugins/operators/update_status.py
+++ b/src/plugins/operators/update_status.py
@@ -20,6 +20,7 @@ import logging
 
 from airflow.models import BaseOperator, Variable
 from libs.context import Context
+from libs.refresh_token import AirflowTokenRefresher
 from libs.exceptions import PipelineFailedError
 from libs.update_status import UpdateStatus
 
@@ -62,7 +63,7 @@ class UpdateStatusOperator(BaseOperator):
         If they are then update status FINISHED else FAILED
         """
         conf = copy.deepcopy(context["dag_run"].conf)
-        logger.info(f"Got conf {conf}.")
+        logger.debug(f"Got conf {conf}.")
         if "Payload" in conf:
             payload_context = Context.populate(conf)
         else:
@@ -74,6 +75,7 @@ class UpdateStatusOperator(BaseOperator):
             workflow_url=Variable.get("update_status_url"),
             workflow_id=workflow_id,
             status=status,
+            token_refresher=AirflowTokenRefresher(),
             context=payload_context
         )
         status_updater.update_workflow_status()
diff --git a/tests/plugin-unit-tests/data/workProduct/SeismicTraceData.json b/tests/plugin-unit-tests/data/workProduct/SeismicTraceData.json
index 69463987d472af381678e8e3e15460108c36b688..9c4035d36cadd21c16e54b8a25b0ad1844a6f344 100644
--- a/tests/plugin-unit-tests/data/workProduct/SeismicTraceData.json
+++ b/tests/plugin-unit-tests/data/workProduct/SeismicTraceData.json
@@ -375,7 +375,7 @@
         "resourceSecurityClassification": "srn:opendes:reference-data/ResourceSecurityClassification:RESTRICTED:",
         "data": {
           "SchemaFormatTypeID": "srn:opendes:reference-data/SchemaFormatType:SEG-Y Seismic Trace Data:",
-          "PreLoadFilePath": "C:\\Seismic\\ST0202R08_PS_PSDM_RAW_PP_TIME.MIG_RAW.POST_STACK.3D.JS-017534.segy",
+          "PreLoadFilePath": "test",
           "FileSource": "",
           "FileSize": 277427976,
           "EncodingFormatTypeID": "srn:opendes:reference-data/EncodingFormatType:segy:",
diff --git a/tests/plugin-unit-tests/data/workProduct/record_SeismicTraceData.json b/tests/plugin-unit-tests/data/workProduct/record_SeismicTraceData.json
index 5e8b95acaf851aee3222f9697cf4f7d3d3464115..0c39e27d8b313625fb546065c412b43440e81ee2 100644
--- a/tests/plugin-unit-tests/data/workProduct/record_SeismicTraceData.json
+++ b/tests/plugin-unit-tests/data/workProduct/record_SeismicTraceData.json
@@ -24,8 +24,8 @@
     "resourceSecurityClassification": "srn:opendes:reference-data/ResourceSecurityClassification:RESTRICTED:",
     "data": {
       "SchemaFormatTypeID": "srn:opendes:reference-data/SchemaFormatType:SEG-Y Seismic Trace Data:",
-      "PreLoadFilePath": "C:\\Seismic\\ST0202R08_PS_PSDM_RAW_PP_TIME.MIG_RAW.POST_STACK.3D.JS-017534.segy",
-      "FileSource": "",
+      "PreLoadFilePath": "test",
+      "FileSource": "test",
       "FileSize": 277427976,
       "EncodingFormatTypeID": "srn:opendes:reference-data/EncodingFormatType:segy:",
       "Endian": "BIG",
diff --git a/tests/plugin-unit-tests/test_file_checker.py b/tests/plugin-unit-tests/test_file_checker.py
new file mode 100644
index 0000000000000000000000000000000000000000..cf09ee908e7b38aa76ae88b02106bc1f87809483
--- /dev/null
+++ b/tests/plugin-unit-tests/test_file_checker.py
@@ -0,0 +1,91 @@
+# Copyright 2020 Google LLC
+# Copyright 2020 EPAM Systems
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import os
+import sys
+
+sys.path.append(f"{os.getenv('AIRFLOW_SRC_DIR')}/plugins")
+sys.path.append(f"{os.getenv('AIRFLOW_SRC_DIR')}/dags")
+
+import pytest
+from libs.exceptions import FileSourceError, GCSObjectURIError
+from libs.source_file_check import GCSSourceFileChecker
+
+
+class TestSourceFileChecker:
+
+    @pytest.fixture()
+    def file_checker(self):
+        return GCSSourceFileChecker()
+
+    @pytest.fixture()
+    def mock_file_exist_under_uri(self, monkeypatch, file_exists: bool):
+        """
+        Mock the GCS response telling whether the file exists or not.
+        """
+        monkeypatch.setattr(GCSSourceFileChecker, "_does_file_exist_in_bucket",
+                            lambda *args, **kwargs: file_exists)
+
+    @pytest.mark.parametrize(
+        "file_path, file_exists",
+        [
+            pytest.param("gs://test/test", True, id="Valid URI")
+        ]
+    )
+    def test_does_file_exist_in_bucket(
+        self,
+        monkeypatch,
+        mock_file_exist_under_uri,
+        file_exists: bool,
+        file_checker: GCSSourceFileChecker,
+        file_path: str
+    ):
+        """
+        Test that the check passes when the file really exists.
+        """
+        file_checker.does_file_exist(file_path)
+
+    @pytest.mark.parametrize(
+        "file_path, file_exists",
+        [
+            pytest.param("gs://test/test", False)
+        ]
+    )
+    def test_file_does_not_exist_in_bucket(
+        self,
+        monkeypatch,
+        mock_file_exist_under_uri,
+        file_exists: bool,
+        file_checker: GCSSourceFileChecker,
+        file_path: str
+    ):
+        """
+        Test that FileSourceError is raised when the file doesn't exist.
+        """
+        with pytest.raises(FileSourceError):
+            file_checker.does_file_exist(file_path)
+
+    @pytest.mark.parametrize(
+        "file_path",
+        [
+            pytest.param("gs://test"),
+            pytest.param("://test"),
+            pytest.param("test"),
+        ]
+    )
+    def test_invalid_gcs_object_uri(self, file_checker: GCSSourceFileChecker, file_path):
+        with pytest.raises(GCSObjectURIError):
+            file_checker._parse_object_uri(file_path)
diff --git a/tests/plugin-unit-tests/test_file_uplaod.py b/tests/plugin-unit-tests/test_file_uplaod.py
new file mode 100644
index 0000000000000000000000000000000000000000..9042b30cf5f7138c702c30498676067a379058f2
--- /dev/null
+++ b/tests/plugin-unit-tests/test_file_uplaod.py
@@ -0,0 +1,65 @@
+# Copyright 2020 Google LLC
+# Copyright 2020 EPAM Systems
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import io
+import os
+import sys
+
+sys.path.append(f"{os.getenv('AIRFLOW_SRC_DIR')}/plugins")
+sys.path.append(f"{os.getenv('AIRFLOW_SRC_DIR')}/dags")
+
+from libs.exceptions import GCSObjectURIError
+import pytest
+from libs.context import Context
+from libs.refresh_token import AirflowTokenRefresher
+from libs.upload_file import GCSFileUploader
+
+
+class TestGCSFileUploader:
+
+    @pytest.fixture()
+    def file_uploader(self, monkeypatch):
+        context = Context(data_partition_id="test", app_key="")
+        file_uploader = GCSFileUploader("http://test", AirflowTokenRefresher(),
+                                        context)
+        monkeypatch.setattr(file_uploader, "_get_signed_url_request",
+                            lambda *args, **kwargs: ("test", "test"))
+        monkeypatch.setattr(file_uploader, "_upload_file_request",
+                            lambda *args, **kwargs: None)
+        monkeypatch.setattr(file_uploader, "_get_file_location_request",
+                            lambda *args, **kwargs: "test")
+        return file_uploader
+
+    def test_upload_file(
+        self,
+        monkeypatch,
+        file_uploader: GCSFileUploader
+    ):
+        file = io.RawIOBase()
+        monkeypatch.setattr(file_uploader, "get_file_from_bucket", lambda *args, **kwargs: file)
+        file_uploader.upload_file("gs://test/test")
+
+    @pytest.mark.parametrize(
+        "file_path",
+        [
+            pytest.param("gs://test"),
+            pytest.param("://test"),
+            pytest.param("test"),
+        ]
+    )
+    def test_invalid_gcs_object_uri(self, file_uploader: GCSFileUploader,
+                                    file_path: str):
+        with pytest.raises(GCSObjectURIError):
+            file_uploader._parse_object_uri(file_path)
diff --git a/tests/plugin-unit-tests/test_manifest_processor_r3.py b/tests/plugin-unit-tests/test_manifest_processor_r3.py
index 5a696615f32e6befba8c9f0d43cbefae40610a25..9bc83af9bcb1f8e5aa865ba194e56fcffdb68c8e 100644
--- a/tests/plugin-unit-tests/test_manifest_processor_r3.py
+++ b/tests/plugin-unit-tests/test_manifest_processor_r3.py
@@ -23,6 +23,8 @@ sys.path.append(f"{os.getenv('AIRFLOW_SRC_DIR')}/plugins")
 sys.path.append(f"{os.getenv('AIRFLOW_SRC_DIR')}/dags")
 
 from libs.context import Context
+from libs.upload_file import GCSFileUploader
+from libs.refresh_token import AirflowTokenRefresher
 from libs.exceptions import EmptyManifestError
 from deepdiff import DeepDiff
 import pytest
@@ -78,12 +80,19 @@ class TestManifestProcessor:
         with open(conf_path) as f:
             conf = json.load(f)
         context = Context.populate(conf)
+        token_refresher = AirflowTokenRefresher()
+        file_uploader = GCSFileUploader("test", token_refresher, context)
         manifest_processor = process_manifest_r3.ManifestProcessor(
             storage_url="",
             dagrun_conf=conf,
-            context=context
+            token_refresher=token_refresher,
+            context=context,
+            file_uploader=file_uploader
         )
+        monkeypatch.setattr(manifest_processor, "generate_id", lambda manifest: "test_id")
+        monkeypatch.setattr(file_uploader, "upload_file",
+                            lambda *args, **kwargs: "test")
         return manifest_processor
 
     @pytest.fixture()
@@ -240,7 +249,9 @@ class TestManifestProcessor:
         manifest_processor = process_manifest_r3.ManifestProcessor(
             storage_url="",
             dagrun_conf=conf,
-            context=context
+            token_refresher=AirflowTokenRefresher(),
+            context=context,
+            file_uploader=GCSFileUploader("test", AirflowTokenRefresher(), context)
         )
         for manifest_part in manifest_processor.data_object["manifest"]:
             group_type = manifest_part["groupType"]
diff --git a/tests/plugin-unit-tests/test_operators_r3.py b/tests/plugin-unit-tests/test_operators_r3.py
index e184a5ee80b3e10f7d0ae189727205ed75fa42de..2b34ce61960ac65288a6fdab594ac13d1211eb60 100644
--- a/tests/plugin-unit-tests/test_operators_r3.py
+++ b/tests/plugin-unit-tests/test_operators_r3.py
@@ -37,8 +37,9 @@ from file_paths import (
 from operators.process_manifest_r3 import ProcessManifestOperatorR3, SchemaValidator, \
     ManifestProcessor
 from operators.search_record_id import SearchRecordIdOperator
-from hooks.http_hooks import workflow_hook
 from operators.update_status import UpdateStatusOperator
+from libs.upload_file import GCSFileUploader
+from libs.refresh_token import AirflowTokenRefresher
 from mock_responses import MockSearchResponse, MockWorkflowResponse
 
 CustomOperator = TypeVar("CustomOperator")
@@ -72,6 +73,9 @@ class TestOperators(object):
         monkeypatch.setattr(SchemaValidator, "validate_manifest", lambda obj: None)
         monkeypatch.setattr(ManifestProcessor, "save_record",
                             lambda obj, headers, request_data: MockStorageResponse())
+        monkeypatch.setattr(GCSFileUploader, "upload_file",
+                            lambda *args, **kwargs: "test")
+
         task, context = self._create_task(ProcessManifestOperatorR3)
         task.pre_execute(context)
         task.execute(context)
diff --git a/tests/plugin-unit-tests/test_schema_validator_r3.py b/tests/plugin-unit-tests/test_schema_validator_r3.py
index ee258930cc2b626d4da5db18e73bd1a98a153f8d..bae115d64e9f77b5e10b1f4e9a6e2f35531f9b12 100644
--- a/tests/plugin-unit-tests/test_schema_validator_r3.py
+++ b/tests/plugin-unit-tests/test_schema_validator_r3.py
@@ -36,6 +36,7 @@ from file_paths import (
 )
 from mock_responses import MockSchemaResponse
 from libs.context import Context
+from libs.refresh_token import AirflowTokenRefresher
 from libs.exceptions import EmptyManifestError, NotOSDUShemaFormatError
 import pytest
 
@@ -55,6 +56,7 @@ class TestSchemaValidator:
         validator = SchemaValidator(
             "",
             conf,
+            AirflowTokenRefresher(),
             context
         )
         if schema_file:
diff --git a/tests/plugin-unit-tests/test_search_record_id.py b/tests/plugin-unit-tests/test_search_record_id.py
index 63db7f5569468769984d02330b1345edd058133d..a7faccbe74dfdc17facbd05ca00ac68f8d5f58a9 100644
--- a/tests/plugin-unit-tests/test_search_record_id.py
+++ b/tests/plugin-unit-tests/test_search_record_id.py
@@ -29,6 +29,7 @@ from file_paths import (
 )
 from libs.exceptions import RecordsNotSearchableError
 from libs.context import Context
+from libs.refresh_token import AirflowTokenRefresher
 from tenacity import stop_after_attempt
 from libs.search_record_ids import SearchId
 from mock_responses import MockSearchResponse
@@ -66,7 +67,8 @@ class TestManifestProcessor:
     def test_search_found_all_records(self, monkeypatch, record_ids: list,
                                       search_response_path: str):
         self.mock_storage_response(monkeypatch, search_response_path, total_count=len(record_ids))
-        id_searcher = SearchId("http://test", record_ids, Context(app_key="", data_partition_id=""))
+        id_searcher = SearchId("http://test", record_ids, AirflowTokenRefresher(),
+                               Context(app_key="", data_partition_id=""))
         id_searcher.check_records_searchable()
 
     @pytest.mark.parametrize(
@@ -87,7 +89,8 @@ class TestManifestProcessor:
         invalid_total_count = len(record_ids) - 1
         self.mock_storage_response(monkeypatch, search_response_path,
                                    total_count=invalid_total_count)
-        id_searcher = SearchId("", record_ids, Context(app_key="", data_partition_id=""))
+        id_searcher = SearchId("", record_ids, AirflowTokenRefresher(),
+                               Context(app_key="", data_partition_id=""))
         with pytest.raises(RecordsNotSearchableError):
             id_searcher.check_records_searchable()
 
@@ -107,7 +110,8 @@ class TestManifestProcessor:
     def test_search_got_wrong_response_value(self, monkeypatch, record_ids: list,
                                              search_response_path: str):
         self.mock_storage_response(monkeypatch, search_response_path)
-        id_searcher = SearchId("http://test", record_ids, Context(app_key="", data_partition_id=""))
+        id_searcher = SearchId("http://test", record_ids, AirflowTokenRefresher(),
+                               Context(app_key="", data_partition_id=""))
         with pytest.raises(ValueError):
             id_searcher.check_records_searchable()
 
@@ -124,4 +128,5 @@ class TestManifestProcessor:
                          search_response_path: str):
         self.mock_storage_response(monkeypatch, search_response_path)
         with pytest.raises(ValueError):
-            SearchId("http://test", record_ids, Context(app_key="", data_partition_id=""))
+            SearchId("http://test", record_ids, AirflowTokenRefresher(),
+                     Context(app_key="", data_partition_id=""))
SearchId("http://test", record_ids, Context(app_key="", data_partition_id="")) + id_searcher = SearchId("http://test", record_ids, AirflowTokenRefresher(), + Context(app_key="", data_partition_id="")) with pytest.raises(ValueError): id_searcher.check_records_searchable() @@ -124,4 +128,5 @@ class TestManifestProcessor: search_response_path: str): self.mock_storage_response(monkeypatch, search_response_path) with pytest.raises(ValueError): - SearchId("http://test", record_ids, Context(app_key="", data_partition_id="")) + SearchId("http://test", record_ids, AirflowTokenRefresher(), + Context(app_key="", data_partition_id="")) diff --git a/tests/plugin-unit-tests/test_update_status_r3.py b/tests/plugin-unit-tests/test_update_status_r3.py index df05b36f845fa74e2e1fb89f4dcdd5a1599317b5..1dfd85fe347ab662d509a92c79c040f7f37763d7 100644 --- a/tests/plugin-unit-tests/test_update_status_r3.py +++ b/tests/plugin-unit-tests/test_update_status_r3.py @@ -28,6 +28,7 @@ from file_paths import ( MANIFEST_WELLBORE_VALID_PATH ) from libs.context import Context +from libs.refresh_token import AirflowTokenRefresher from libs.update_status import UpdateStatus from mock_responses import MockWorkflowResponse @@ -43,6 +44,7 @@ class TestUpdateStatus: status_updater = UpdateStatus( workflow_url = "http://test", workflow_id=workflow_id, + token_refresher=AirflowTokenRefresher(), context=context, status=status ) diff --git a/tests/set_airflow_env.sh b/tests/set_airflow_env.sh index 8b750c31184db33e83d44d25b876e0f41e1c6eaa..3bdcc1f2a813bf41f0568fd7459d88af95411807 100755 --- a/tests/set_airflow_env.sh +++ b/tests/set_airflow_env.sh @@ -48,6 +48,7 @@ airflow variables -s provider gcp airflow variables -s record_kind "odes:osdu:file:0.2.0" airflow variables -s schema_version "0.2.0" airflow variables -s workflow_url $WORKFLOW_URL +airflow variables -s file_service_url $LOCALHOST airflow variables -s update_status_url $UPDATE_STATUS_URL airflow variables -s search_url $SEARCH_URL airflow variables -s schema_service_url $LOCALHOST