diff --git a/src/dags/libs/utils.py b/src/dags/libs/utils.py index 50d6a4eb38009fb5a1b8af4f77b03a9947379239..83513d8170f1ceeee91b4ff8b86c895ffde11dff 100644 --- a/src/dags/libs/utils.py +++ b/src/dags/libs/utils.py @@ -16,7 +16,11 @@ """Util functions to work with OSDU Manifests.""" from typing import Tuple +import re +# same regex as in core common, except the hyphen is escaped (Python's re rejects "[\w-\.]" as a bad character range) +RECORD_ID_REGEX = r"^[\w\-\.]+:[\w\-\.]+:[\w\-\.\:\%]+$" +RECORD_ID_WITH_VERSION_REGEX = r"^[\w\-\.]+:[\w\-\.]+:[\w\-\.\:\%]+:[0-9]+$" def remove_trailing_colon(id_value: str) -> str: """ @@ -31,15 +35,22 @@ def remove_trailing_colon(id_value: str) -> str: def split_id(id_value: str) -> Tuple[str, str]: """ Get id without a version for searching later. + A versioned record id has at least four colon-separated pieces, the last being the numeric version. :id_value: ID of some entity with or without versions. """ + + _id = id_value version = "" + + # make sure trailing colons get removed if id_value.endswith(":"): _id = id_value[:-1] - elif id_value.split(":")[-1].isdigit(): - version = str(id_value.split(":")[-1]) - _id = id_value[:-len(version) - 1] - else: - _id = id_value + version = "" + + if re.match(RECORD_ID_WITH_VERSION_REGEX, id_value): + # the third id segment may itself contain colons, + # so only the last colon separates the version + _id, version = id_value.rsplit(":", 1) + return _id, version diff --git a/src/dags/providers/aws/__init__.py b/src/dags/providers/aws/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..79348898f3b74d25d88c0adc672a0783fc0836ee --- /dev/null +++ b/src/dags/providers/aws/__init__.py @@ -0,0 +1,13 @@ +# Copyright © Amazon Web Services +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/src/dags/providers/aws/aws_blob_storage_client.py b/src/dags/providers/aws/aws_blob_storage_client.py new file mode 100644 index 0000000000000000000000000000000000000000..d13d3d7ca3a62d203b3edee1b04ec28a97b01ad2 --- /dev/null +++ b/src/dags/providers/aws/aws_blob_storage_client.py @@ -0,0 +1,141 @@ +# Copyright © Amazon Web Services +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Blob storage AWS client module""" + +# import tenacity +from manifest_ingestion.providers.constants import AWS_CLOUD_PROVIDER +# import logging +# import boto3 +# from boto3.exceptions import ClientError +from manifest_ingestion.providers.factory import ProvidersFactory +from manifest_ingestion.providers.types import BlobStorageClient, FileLikeObject +from typing import Tuple +# import io + +# logger = logging.getLogger(__name__) + +# RETRY_SETTINGS = { +# "stop": tenacity.stop_after_attempt(3), +# "wait": tenacity.wait_fixed(10), +# "reraise": True, +# } + +@ProvidersFactory.register(AWS_CLOUD_PROVIDER) +class AwsCloudStorageClient(BlobStorageClient): + """Blob storage client for the AWS provider (stub: every method body is currently a no-op `pass`).""" + def __init__(self): + """Initialize storage client.""" + pass + + def does_file_exist(self, uri: str) -> bool: + """Verify if a file exists in the given URI. + + :param uri: The AWS URI of the file. + :type uri: str + :return: A boolean indicating if the file exists (not implemented yet: currently returns None) + :rtype: bool + """ + + # assuming the URI here is an s3:// URI + # get the bucket name and path to object + # bucket_name, object_name = self._split_s3_path(uri) + # s3_client = boto3.client('s3') + # try: + # # try to get the s3 metadata for the object, which is a + # # fast operation no matter the size of the data object + # s3_client.head_object(Bucket=bucket_name, Key=object_name) + # except ClientError: + # return False + # return True + pass + + def download_to_file(self, uri: str, file: FileLikeObject) -> Tuple[FileLikeObject, str]: + """Download file from the given URI. + + :param uri: The AWS URI of the file. 
+ :type uri: str + :param file: The file-like object to download the blob content into + :type file: FileLikeObject + :return: A tuple containing the file and its content-type (not implemented yet: currently returns None) + :rtype: Tuple[FileLikeObject, str] + """ + # assuming the URI here is an s3:// URI + # get the bucket name, path to object + # bucket_name, object_name = self._split_s3_path(uri) + # s3_client = boto3.client('s3') + # buffer = io.BytesIO() + # content_type = "" + # file_bytes = s3_client.download_fileobj(Bucket=bucket_name, Key=object_name, Filename=buffer) + + # return file_bytes, content_type + pass + + + def download_file_as_bytes(self, uri: str) -> Tuple[bytes, str]: + """Download file as bytes from the given URI. + + :param uri: The AWS URI of the file + :type uri: str + :return: The file as bytes and its content-type (not implemented yet: currently returns None) + :rtype: Tuple[bytes, str] + """ + # assuming the URI here is an s3:// URI + # get the bucket name, path to object + # bucket_name, object_name = self._split_s3_path(uri) + # filename = object_name.split('/')[-1] + # s3_client = boto3.client('s3') + # file_handle = io.BytesIO + # return s3_client.download_fileobj(Bucket=bucket_name, Key=object_name, file_handle) + pass + + def upload_file(self, uri: str, blob_file: FileLikeObject, content_type: str): + """Upload a file to the given uri. 
+ + :param uri: The AWS URI of the file + :type uri: str + :param blob_file: The file-like object to upload + :type blob_file: FileLikeObject + :param content_type: Content type of the file (unused while this method is a stub) + :type content_type: str + """ + # assuming the URI here is an s3:// URI + # get the bucket name, path to object + # bucket_name, object_name = self._split_s3_path(uri) + + # s3_client = boto3.client( + # 's3', + # aws_access_key_id=access_key_id, + # aws_secret_access_key=secret_key, + # aws_session_token=session_token + # ) + + # # Upload the file like object + # s3_client.upload_fileobj(blob_file, bucket_name, object_name) + pass + + def _split_s3_path(self, s3_path:str): + """Split an s3:// path into bucket and key parts (not implemented yet: currently returns None) + + Args: + s3_path (str): an s3:// uri + + Returns: + tuple: bucket name, key name ( with path ) + """ + # path_parts=s3_path.replace("s3://","").split("/") + # bucket=path_parts.pop(0) + # key="/".join(path_parts) + # return bucket, key + pass \ No newline at end of file diff --git a/src/dags/providers/aws/aws_credentials.py b/src/dags/providers/aws/aws_credentials.py new file mode 100644 index 0000000000000000000000000000000000000000..14b720dc8330f12f03c7fb28da75147da6a82371 --- /dev/null +++ b/src/dags/providers/aws/aws_credentials.py @@ -0,0 +1,55 @@ +# Copyright © Amazon Web Services +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""AWS Credential Module.""" + +import logging +from manifest_ingestion.providers.constants import AWS_CLOUD_PROVIDER +from manifest_ingestion.providers.factory import ProvidersFactory +from manifest_ingestion.providers.types import BaseCredentials +from manifest_ingestion.providers.aws.service_principal_util import get_service_principal_token +from tenacity import retry, stop_after_attempt +import os + +logger = logging.getLogger(__name__) +RETRIES = 3 # maximum number of attempts made by AWSCredentials.refresh_token + +@ProvidersFactory.register(AWS_CLOUD_PROVIDER) +class AWSCredentials(BaseCredentials): + """AWS Credential Provider""" + + def __init__(self): + """Initialize AWS Credentials object with no token; call refresh_token() to obtain one.""" + self._access_token = None + + @retry(stop=stop_after_attempt(RETRIES)) + def refresh_token(self) -> str: + """Fetch a new service principal token and cache it on this object. + + :return: Refreshed token + :rtype: str + """ + token = get_service_principal_token() + self._access_token = token + return self._access_token + + @property + def access_token(self) -> str: + """The access token. + + :return: Access token string, or None until refresh_token() has been called. + :rtype: str + """ + return self._access_token + diff --git a/src/dags/providers/aws/service_principal_util.py b/src/dags/providers/aws/service_principal_util.py new file mode 100644 index 0000000000000000000000000000000000000000..6c11f6c7e59f0d5ffd047ef4a3768c8f918bebd0 --- /dev/null +++ b/src/dags/providers/aws/service_principal_util.py @@ -0,0 +1,96 @@ +# Copyright © Amazon Web Services +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os +import base64 +import boto3 +import requests +import json +from botocore.exceptions import ClientError +from configparser import ConfigParser + +def _get_ssm_parameter(session, region_name, ssm_path): + """Return the value of the SSM parameter stored at ssm_path.""" + ssm_client = session.client('ssm', region_name=region_name) + ssm_response = ssm_client.get_parameter(Name=ssm_path) + return ssm_response['Parameter']['Value'] + +def _get_secret(session, region_name, secret_name, secret_dict_key): + """Fetch secret_name from Secrets Manager and return the secret_dict_key entry of its JSON payload.""" + # Create a Secrets Manager client + client = session.client( + service_name='secretsmanager', + region_name=region_name + ) + + # In this sample we only handle the specific exceptions for the 'GetSecretValue' API. + # See https://docs.aws.amazon.com/secretsmanager/latest/apireference/API_GetSecretValue.html + # We rethrow the exception by default. + + try: + get_secret_value_response = client.get_secret_value( + SecretId=secret_name + ) + except ClientError as e: + print("Could not get client secret from secrets manager") + raise e + else: + # Decrypts secret using the associated KMS CMK. + # Depending on whether the secret is a string or binary, one of these fields will be populated. 
+ if 'SecretString' in get_secret_value_response: + return_secret_serialized = get_secret_value_response['SecretString'] + else: + return_secret_serialized = base64.b64decode(get_secret_value_response['SecretBinary']) + + # json.loads accepts both str (SecretString) and bytes (decoded SecretBinary) + return_secret = json.loads(return_secret_serialized)[secret_dict_key] + + return return_secret + + +def get_service_principal_token(): + """Return an OAuth access token obtained with the client-credentials grant configured in osdu_api.ini.""" + config_parser = ConfigParser(os.environ) + config_file_name = 'osdu_api.ini' + + found_names = config_parser.read(config_file_name) + if config_file_name not in found_names: + raise Exception('Could not find osdu_api.ini config file') + + client_id_ssm_path = config_parser.get('provider', 'client_id_ssm_path') + client_secret_name = config_parser.get('provider', 'client_secret_name') + client_secret_dict_key = config_parser.get('provider', 'client_secret_dict_key') + aws_oauth_custom_scope_ssm_path = config_parser.get('provider', 'aws_oauth_custom_scope_ssm_path') + region_name = config_parser.get('provider', 'region_name') + token_url_ssm_path = config_parser.get('provider', 'token_url_ssm_path') + + session = boto3.session.Session() + client_id = _get_ssm_parameter(session, region_name, client_id_ssm_path) + client_secret = _get_secret(session, region_name, client_secret_name, client_secret_dict_key) + token_url = _get_ssm_parameter(session, region_name, token_url_ssm_path) + aws_oauth_custom_scope = _get_ssm_parameter(session, region_name, aws_oauth_custom_scope_ssm_path) + + auth = '{}:{}'.format(client_id, client_secret) + encoded_auth = base64.b64encode(str.encode(auth)) + + headers = {} + headers['Authorization'] = 'Basic ' + encoded_auth.decode() + headers['Content-Type'] = 'application/x-www-form-urlencoded' + + token_url = '{}?grant_type=client_credentials&client_id={}&scope={}'.format(token_url,client_id, aws_oauth_custom_scope) + + response = requests.post(url=token_url, headers=headers) + # 
return 'Bearer {}'.format(json.loads(response.content.decode())['access_token']) + response.raise_for_status() # surface HTTP errors here; the token is a credential and must never be printed + return json.loads(response.content.decode())['access_token'] diff --git a/tests/plugin-unit-tests/test_utils.py b/tests/plugin-unit-tests/test_utils.py index 0898ea3fefa2c00aa82691ba94357f464dad3cb4..c2bc7a2a5f4277eb5007d0595b536204b0633a7a 100644 --- a/tests/plugin-unit-tests/test_utils.py +++ b/tests/plugin-unit-tests/test_utils.py @@ -20,7 +20,6 @@ import pytest sys.path.append(f"{os.getenv('AIRFLOW_SRC_DIR')}/plugins") sys.path.append(f"{os.getenv('AIRFLOW_SRC_DIR')}/dags") - from libs.utils import split_id, remove_trailing_colon @@ -30,12 +29,12 @@ class TestUtils: "_id,expected_id", [ pytest.param( - "test:test:", - "test:test", + "osdu:master-data--Wellbore:1000:", + "osdu:master-data--Wellbore:1000", id="Trailing colon"), pytest.param( - "test:test", - "test:test", + "osdu:master-data--Wellbore:1000", + "osdu:master-data--Wellbore:1000", id="With no colon") ] ) @@ -46,16 +45,16 @@ class TestUtils: "_id,expected_id_version", [ pytest.param( - "test:test:", - ("test:test", ""), + "osdu:master-data--Wellbore:1000:", + ("osdu:master-data--Wellbore:1000", ""), id="Trailing colon"), pytest.param( - "test:test", - ("test:test", ""), + "osdu:master-data--Wellbore:1000", + ("osdu:master-data--Wellbore:1000", ""), id="With no colon"), pytest.param( - "test:test:1", - ("test:test", "1"), + "osdu:master-data--Wellbore:1000:1", + ("osdu:master-data--Wellbore:1000", "1"), id="With version") ] )