Commit eab1c62a authored by Spencer Sutton

Boto client factory, ingestion workflow client

commit d1a8ecb2 
Author: Spencer Sutton <suttonsp@amazon.com> 
Date: Thu Aug 26 2021 13:45:37 GMT-0500 (Central Daylight Time) 

    Adding back tests


commit 7a79138a 
Author: Spencer Sutton <suttonsp@amazon.com> 
Date: Thu Aug 26 2021 08:57:32 GMT-0500 (Central Daylight Time) 

    Incrementing to version 9


commit 245d1d25 
Author: Spencer Sutton <suttonsp@amazon.com> 
Date: Thu Aug 26 2021 08:47:08 GMT-0500 (Central Daylight Time) 

    Adding explicit constructor to child class


commit 185bbfef 
Author: Spencer Sutton <suttonsp@amazon.com> 
Date: Wed Aug 25 2021 17:15:20 GMT-0500 (Central Daylight Time) 

    Fixing base client now


commit 7f0e617a 
Author: Spencer Sutton <suttonsp@amazon.com> 
Date: Wed Aug 25 2021 16:02:27 GMT-0500 (Central Daylight Time) 

    Fixing service principal util


commit 6b91d436 
Author: Spencer Sutton <suttonsp@amazon.com> 
Date: Wed Aug 25 2021 12:25:42 GMT-0500 (Central Daylight Time) 

    Had env var in wrong spot...
parent 1886ceaf
......@@ -31,6 +31,7 @@ env:
    CUSTOM_SCOPE: stub
    ENVIRONMENT: stub
    AWS_REGION: stub
    CI_COMMIT_TAG: v0.0.9
phases:
  install:
......@@ -52,9 +53,10 @@ phases:
      - rm osdu_api.ini
      # publish new artifact to code artifact
      - aws codeartifact login --tool twine --domain osdu-dev --domain-owner 888733619319 --repository osdu-python
      - export AWS_ACCOUNT_ID=`aws sts get-caller-identity | grep Account | cut -d':' -f 2 | cut -d'"' -f 2`
      - aws codeartifact login --tool twine --domain osdu-dev --domain-owner ${AWS_ACCOUNT_ID} --repository osdu-python
      - python setup.py sdist bdist_wheel
      - twine upload --skip-existing --repository codeartifact dist/osdu_api-0.0.6.tar.gz
      - twine upload --skip-existing --repository codeartifact dist/osdu_api-0.0.9.tar.gz
artifacts:
......
......@@ -70,6 +70,7 @@ class BaseClient:
        self.dataset_url = config_parser.get('environment', 'dataset_url')
        self.use_service_principal = config_parser.get('environment', 'use_service_principal')
        self.schema_url = config_parser.get('environment', 'schema_url')
        self.ingestion_workflow_url = config_parser.get('environment', 'ingestion_workflow_url')
        self.provider = config_parser.get('provider', 'name')
        self.service_principal_module_name = config_parser.get('provider', 'service_principal_module_name')
......@@ -86,13 +87,16 @@ class BaseClient:
        Makes a request using python's built in requests library. Takes additional headers if
        necessary
        """
        if bearer_token is None:
            bearer_token = self.service_principal_token
        if bearer_token is not None and 'Bearer ' not in bearer_token:
            bearer_token = 'Bearer ' + bearer_token
        headers = {
            'content-type': 'application/json',
            'data-partition-id': self.data_partition_id,
            'Authorization': bearer_token if bearer_token is not None else self.service_principal_token
            'Authorization': bearer_token
        }
        if (len(add_headers) > 0):
......
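The net effect of this hunk: callers may pass a bare token, a token already prefixed with 'Bearer ', or nothing at all to fall back on the cached service principal token. A condensed restatement of that resolution (a sketch, not additional library code):

# Sketch restating the token resolution in the hunk above
def resolve_authorization(bearer_token, service_principal_token):
    # an explicit token wins; otherwise fall back to the service principal token
    token = bearer_token if bearer_token is not None else service_principal_token
    # add the scheme only if it is not already present
    if token is not None and 'Bearer ' not in token:
        token = 'Bearer ' + token
    return token

resolve_authorization(None, 'abc123')          # -> 'Bearer abc123'
resolve_authorization('Bearer xyz', 'abc123')  # -> 'Bearer xyz'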
......@@ -23,6 +23,8 @@ class DatasetDmsClient(BaseClient):
"""
Holds the logic for interfacing with Data Registry Service's DMS api
"""
def __init__(self, data_partition_id=None):
super().__init__(data_partition_id)
def get_storage_instructions(self, kind_sub_type: str, bearer_token=None):
params = {'kindSubType': kind_sub_type}
......
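A small usage sketch of the explicit constructor added here; the import path, partition id and kind sub type are illustrative assumptions rather than values from this change:

# Illustrative values and assumed import path
from osdu_api.clients.dataset.dataset_dms_client import DatasetDmsClient

dms_client = DatasetDmsClient(data_partition_id='opendes')
response = dms_client.get_storage_instructions(kind_sub_type='dataset--File.Generic')
print(response.status_code)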
# Copyright © 2020 Amazon Web Services
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from typing import List
from osdu_api.clients.base_client import BaseClient
from osdu_api.model.ingestion_workflow.create_workflow_request import CreateWorkflowRequest
from osdu_api.model.ingestion_workflow.trigger_workflow_request import TriggerWorkflowRequest
from osdu_api.model.ingestion_workflow.update_workflow_run_request import UpdateWorkflowRunRequest
from osdu_api.model.http_method import HttpMethod


class IngestionWorkflowClient(BaseClient):
    """
    Holds the logic for interfacing with Ingestion Workflow's api
    """

    def get_workflow(self, workflow_name: str, bearer_token=None):
        params = {'workflowName': workflow_name}
        return self.make_request(method=HttpMethod.GET, url='{}{}'.format(self.ingestion_workflow_url, '/workflow'),
                                 params=params, bearer_token=bearer_token)

    def create_workflow(self, create_workflow_request: CreateWorkflowRequest, bearer_token=None):
        return self.make_request(method=HttpMethod.POST, url='{}{}'.format(self.ingestion_workflow_url, '/workflow'),
                                 data=create_workflow_request.to_JSON(), bearer_token=bearer_token)

    def get_all_workflows_in_partition(self, bearer_token=None):
        return self.make_request(method=HttpMethod.GET, url='{}{}'.format(self.ingestion_workflow_url, '/workflow'),
                                 bearer_token=bearer_token)

    def delete_workflow(self, workflow_name: str, bearer_token=None):
        params = {'workflowName': workflow_name}
        return self.make_request(method=HttpMethod.DELETE, url='{}{}'.format(self.ingestion_workflow_url, '/workflow'),
                                 params=params, bearer_token=bearer_token)

    def trigger_workflow(self, trigger_workflow_request: TriggerWorkflowRequest, workflow_name: str, bearer_token=None):
        return self.make_request(method=HttpMethod.POST, url='{}{}{}{}'.format(self.ingestion_workflow_url, '/workflow/', workflow_name, '/workflowRun'),
                                 data=trigger_workflow_request.to_JSON(), bearer_token=bearer_token)

    def get_workflow_runs(self, workflow_name: str, bearer_token=None):
        return self.make_request(method=HttpMethod.GET, url='{}{}{}{}'.format(self.ingestion_workflow_url, '/workflow/', workflow_name, '/workflowRun'),
                                 bearer_token=bearer_token)

    def get_workflow_run_by_id(self, workflow_name: str, run_id: str, bearer_token=None):
        return self.make_request(method=HttpMethod.GET, url='{}{}{}{}{}'.format(self.ingestion_workflow_url, '/workflow/', workflow_name, '/workflowRun/', run_id),
                                 bearer_token=bearer_token)

    def update_workflow_run(self, update_workflow_run_request: UpdateWorkflowRunRequest, workflow_name: str, run_id: str, bearer_token=None):
        return self.make_request(method=HttpMethod.PUT, url='{}{}{}{}{}'.format(self.ingestion_workflow_url, '/workflow/', workflow_name, '/workflowRun/', run_id),
                                 data=update_workflow_run_request.to_JSON(), bearer_token=bearer_token)
\ No newline at end of file
......@@ -79,3 +79,17 @@ class RecordClient(BaseClient):
    def query_records(self, query_records_request: QueryRecordsRequest, bearer_token=None):
        return self.make_request(method=HttpMethod.POST, url='{}{}'.format(self.storage_url, '/query/records'),
                                 data=query_records_request.to_JSON(), bearer_token=bearer_token)

    # ingest bulk records provided as a pre-serialized JSON payload
    def ingest_records(self, records, bearer_token=None):
        """
        Submits the given records payload to the Data Workflow service's /workflowRun endpoint
        and returns the response object for the call
        """
        return self.make_request(method=HttpMethod.POST, url='{}{}'.format(self.data_workflow_url, '/workflowRun'), data=records, bearer_token=bearer_token)

    def query_record(self, recordId: str, bearer_token=None):
        return self.make_request(method=HttpMethod.GET, url='{}{}/{}'.format(self.storage_url, '/records', recordId), bearer_token=bearer_token)
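A hedged usage sketch of the two new RecordClient methods; the import path, record id and payload shape are illustrative assumptions, and ingest_records forwards whatever JSON string it is given:

# Illustrative values and assumed import path
import json
from osdu_api.clients.storage.record_client import RecordClient

record_client = RecordClient()
# look up a single record by id (the id is made up)
response = record_client.query_record('opendes:wellbore:123456')
print(response.status_code)
# submit a pre-serialized records payload; the shape here is only a placeholder
records_payload = json.dumps([{'kind': 'opendes:osdu:wellbore:0.0.1', 'data': {}}])
response = record_client.ingest_records(records_payload)
print(response.status_code)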
from osdu_api.model.ingestion_workflow.create_workflow_request import CreateWorkflowRequest
from osdu_api.model.ingestion_workflow.trigger_workflow_request import TriggerWorkflowRequest
from osdu_api.model.ingestion_workflow.update_workflow_run_request import UpdateWorkflowRunRequest
from osdu_api.clients.ingestion_workflow.ingestion_workflow_client import IngestionWorkflowClient

ingestion_client = IngestionWorkflowClient()

# register a workflow, then exercise the read and delete endpoints when the
# registration succeeds (200) or the workflow already exists (409)
create_workflow_request = CreateWorkflowRequest("test description", {}, "my_second_dag")
response = ingestion_client.create_workflow(create_workflow_request)
print(">>>>>>>")
print(response.status_code)
print(response.content)

if response.status_code == 200 or response.status_code == 409:
    response = ingestion_client.get_workflow("my_second_dag")
    print(">>>>>>>")
    print(response.status_code)
    print(response.content)

    response = ingestion_client.get_all_workflows_in_partition()
    print(">>>>>>>")
    print(response.status_code)
    print(response.content)

    response = ingestion_client.delete_workflow("my_second_dag")
    print(">>>>>>>")
    print(response.status_code)
    print(response.content)

# re-register the workflow and trigger a run
create_workflow_request = CreateWorkflowRequest("test description", {}, "my_second_dag")
response = ingestion_client.create_workflow(create_workflow_request)

if response.status_code == 200 or response.status_code == 409:
    trigger_workflow_request = TriggerWorkflowRequest({})
    response = ingestion_client.trigger_workflow(trigger_workflow_request, "my_second_dag")
    print(">>>>>>>")
    print(response.status_code)
    print(response.content)

    response = ingestion_client.get_workflow_runs("my_second_dag")
    print(">>>>>>>")
    print(response.status_code)
    print(response.content)
......@@ -20,6 +20,7 @@
dataset_url=%(BASE_URL)s/api/dataset/v1
entitlements_url=%(BASE_URL)s/api/entitlements/v1
schema_url=%(BASE_URL)s/api/schema-service/v1
ingestion_workflow_url=%(BASE_URL)s/api/workflow/v1
use_service_principal=True
[provider]
......
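These are ordinary ConfigParser entries that rely on %(BASE_URL)s interpolation, so the new ingestion_workflow_url resolves against whichever BASE_URL the ini defines. A minimal sketch with an illustrative base URL:

# Minimal interpolation sketch -- the base URL is illustrative
from configparser import ConfigParser

config_parser = ConfigParser()
config_parser.read_string("""
[environment]
BASE_URL=https://osdu.example.com
ingestion_workflow_url=%(BASE_URL)s/api/workflow/v1
""")
print(config_parser.get('environment', 'ingestion_workflow_url'))
# https://osdu.example.com/api/workflow/v1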
# Copyright © 2020 Amazon Web Services
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from osdu_api.model.base import Base


class CreateWorkflowRequest(Base):
    """
    Request body to ingestion workflow's create workflow endpoint
    """

    def __init__(self, description: str, registration_instructions: dict, workflow_name: str):
        self.description = description
        self.registrationInstructions = registration_instructions
        self.workflowName = workflow_name
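For reference, assuming Base.to_JSON serializes the instance attributes as-is (the serializer itself is not part of this diff), the request produces the camelCase keys the Workflow service expects:

# Illustrative serialization of the class defined above
request = CreateWorkflowRequest("test description", {}, "my_second_dag")
print(request.to_JSON())
# expected shape: {"description": "test description", "registrationInstructions": {}, "workflowName": "my_second_dag"}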
# Copyright © 2020 Amazon Web Services
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from osdu_api.model.base import Base


class TriggerWorkflowRequest(Base):
    """
    Request body to ingestion workflow's trigger workflow endpoint
    """

    def __init__(self, execution_context: dict):
        self.executionContext = execution_context
# Copyright © 2020 Amazon Web Services
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from osdu_api.model.base import Base


class UpdateWorkflowRunRequest(Base):
    """
    Request body to ingestion workflow's update workflow run endpoint
    """

    def __init__(self, status: str):
        self.status = status
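update_workflow_run is the one IngestionWorkflowClient method the sample script earlier in this change does not exercise; a hedged sketch pairing it with this request class (workflow name, run id and status value are made up):

# Illustrative values only
from osdu_api.clients.ingestion_workflow.ingestion_workflow_client import IngestionWorkflowClient
from osdu_api.model.ingestion_workflow.update_workflow_run_request import UpdateWorkflowRunRequest

ingestion_client = IngestionWorkflowClient()
update_request = UpdateWorkflowRunRequest("finished")
response = ingestion_client.update_workflow_run(update_request, "my_second_dag", "some-run-id")
print(response.status_code)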
import os

import boto3


class BotoClientFactory:

    def get_boto_client(self, client_type: str, region_name: str):
        # default: build the client from the ambient credentials/profile
        session = boto3.session.Session(region_name=region_name)
        client = session.client(
            service_name=client_type,
            region_name=region_name
        )

        # when running inside Airflow, assume the configured role and rebuild the
        # client from the resulting temporary credentials
        if os.environ.get('USE_AIRFLOW', '').lower() == 'true':
            if 'AWS_ROLE_ARN' not in os.environ:
                raise Exception('Must have AWS_ROLE_ARN set')
            sts_client = boto3.client('sts')
            assumed_role_object = sts_client.assume_role(
                RoleArn=os.environ['AWS_ROLE_ARN'],
                RoleSessionName="airflow_session"
            )
            credentials = assumed_role_object['Credentials']
            session = boto3.Session(
                aws_access_key_id=credentials['AccessKeyId'],
                aws_secret_access_key=credentials['SecretAccessKey'],
                aws_session_token=credentials['SessionToken']
            )
            client = session.client(
                service_name=client_type,
                region_name=region_name
            )

        return client
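A quick sketch of how the factory is consumed (the service principal util below uses it the same way); the region and SSM parameter path are illustrative. With USE_AIRFLOW=true and AWS_ROLE_ARN set, the identical call returns a client backed by the assumed role's temporary credentials:

# Illustrative region and parameter path
from osdu_api.providers.aws.boto_client_factory import BotoClientFactory

boto_client_factory = BotoClientFactory()
ssm_client = boto_client_factory.get_boto_client('ssm', 'us-east-1')
ssm_response = ssm_client.get_parameter(Name='/osdu/example/parameter')
print(ssm_response['Parameter']['Value'])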
......@@ -18,18 +18,17 @@ import requests
import json

from botocore.exceptions import ClientError
from configparser import ConfigParser
from osdu_api.providers.aws.boto_client_factory import BotoClientFactory


def _get_ssm_parameter(session, region_name, ssm_path):
    ssm_client = session.client('ssm', region_name=region_name)
    ssm_response = ssm_client.get_parameter(Name=ssm_path)
    return ssm_response['Parameter']['Value']


def _get_ssm_parameter(region_name, ssm_path):
    boto_client_factory = BotoClientFactory()
    ssm_client = boto_client_factory.get_boto_client('ssm', region_name)
    ssm_response = ssm_client.get_parameter(Name=ssm_path)
    return ssm_response['Parameter']['Value']


def _get_secret(session, region_name, secret_name, secret_dict_key):
    # Create a Secrets Manager client
    client = session.client(
        service_name='secretsmanager',
        region_name=region_name
    )


def _get_secret(region_name, secret_name, secret_dict_key):
    boto_client_factory = BotoClientFactory()
    client = boto_client_factory.get_boto_client('secretsmanager', region_name)

    # In this sample we only handle the specific exceptions for the 'GetSecretValue' API.
    # See https://docs.aws.amazon.com/secretsmanager/latest/apireference/API_GetSecretValue.html
......@@ -74,11 +73,10 @@ def get_service_principal_token():
    region_name = config_parser.get('provider', 'region_name')
    token_url_ssm_path = config_parser.get('provider', 'token_url_ssm_path')

    session = boto3.session.Session()
    client_id = _get_ssm_parameter(session, region_name, client_id_ssm_path)
    client_secret = _get_secret(session, region_name, client_secret_name, client_secret_dict_key)
    token_url = _get_ssm_parameter(session, region_name, token_url_ssm_path)
    aws_oauth_custom_scope = _get_ssm_parameter(session, region_name, aws_oauth_custom_scope_ssm_path)
    client_id = _get_ssm_parameter(region_name, client_id_ssm_path)
    client_secret = _get_secret(region_name, client_secret_name, client_secret_dict_key)
    token_url = _get_ssm_parameter(region_name, token_url_ssm_path)
    aws_oauth_custom_scope = _get_ssm_parameter(region_name, aws_oauth_custom_scope_ssm_path)

    auth = '{}:{}'.format(client_id, client_secret)
    encoded_auth = base64.b64encode(str.encode(auth))
......@@ -90,6 +88,4 @@ def get_service_principal_token():
    token_url = '{}?grant_type=client_credentials&client_id={}&scope={}'.format(token_url, client_id, aws_oauth_custom_scope)

    response = requests.post(url=token_url, headers=headers)
    # return 'Bearer {}'.format(json.loads(response.content.decode())['access_token'])
    print(json.loads(response.content.decode())['access_token'])
    return json.loads(response.content.decode())['access_token']
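For completeness, the helper can also be called directly; the module path below is an assumption inferred from the boto_client_factory import above:

# Assumed module path, mirroring osdu_api.providers.aws.boto_client_factory
from osdu_api.providers.aws.service_principal_util import get_service_principal_token

# reads the client id/secret and token URL from SSM and Secrets Manager, then requests an OAuth access token
access_token = get_service_principal_token()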
......@@ -21,6 +21,7 @@ entitlements_url=%(ENTITLEMENTS_BASE_URL)s/api/entitlements/v1
file_dms_url=%(FILE_DMS_BASE_URL)s/api/filedms/v2
dataset_url=%(DATASET_REGISTRY_BASE_URL)s/api/dataset-registry/v1
schema_url=%(SCHEMA_BASE_URL)s/api/schema-service/v1
ingestion_workflow_url=stub
use_service_principal=True
[provider]
......
# Copyright © 2020 Amazon Web Services
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest

import mock

from osdu_api.clients.ingestion_workflow.ingestion_workflow_client import IngestionWorkflowClient
from osdu_api.clients.base_client import BaseClient
from osdu_api.model.http_method import HttpMethod
from osdu_api.model.ingestion_workflow.create_workflow_request import CreateWorkflowRequest


class TestIngestionWorkflowClient(unittest.TestCase):

    @mock.patch.object(BaseClient, 'make_request', return_value="response")
    @mock.patch.object(BaseClient, '_refresh_service_principal_token', return_value="stubbed")
    def test_make_request(self, get_bearer_token_mock, make_request_mock):
        # Arrange
        client = IngestionWorkflowClient()
        client.service_principal_token = 'stubbed'
        client.ingestion_workflow_url = 'stubbed url'
        client.headers = {}
        create_workflow_request = CreateWorkflowRequest("test description", {}, "my_second_dag")

        # Act
        response = client.create_workflow(create_workflow_request)

        # Assert
        assert response == make_request_mock.return_value
\ No newline at end of file
......@@ -4,5 +4,8 @@ requests==2.25.1
strict-rfc3339==0.7
tenacity==8.0.1
toposort==1.6
pytest
dataclasses==0.8;python_version<"3.7"
mock==4.0.2
responses==0.12.1
twine==3.2.0
......@@ -20,7 +20,9 @@ import setuptools
with open("README.md", "r") as fh:
long_description = fh.read()
COMMIT_BASED_VERSION = "0.0.6"
def get_version_from_file():
with open("VERSION", "r") as fh:
return fh.read().strip()
def prepare_version():
version = os.getenv("CI_COMMIT_TAG", '')
......@@ -32,7 +34,7 @@ def prepare_version():
# we assume that it is commit version
# https://packaging.python.org/guides/distributing-packages-using-setuptools/#local-version-identifiers
commit = os.environ["CI_COMMIT_SHORT_SHA"]
version = f"{COMMIT_BASED_VERSION}.dev+{commit}"
version = f"{get_version_from_file()}.dev+{commit}"
return version
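With this change, non-tag CI builds derive their dev version from the VERSION file rather than the hardcoded constant. A tiny sketch reusing get_version_from_file from above (file contents and commit SHA are made up):

# Illustrative only: write a VERSION file and show the dev version the new scheme builds
with open("VERSION", "w") as fh:
    fh.write("0.0.9\n")
commit = "abc1234"  # stands in for CI_COMMIT_SHORT_SHA
print(f"{get_version_from_file()}.dev+{commit}")  # 0.0.9.dev+abc1234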
......@@ -50,12 +52,12 @@ setuptools.setup(
    install_requires=[
        "jsonschema==3.2.0",
        "pyyaml==5.4.1",
        "requests==2.25.1",
        "strict-rfc3339==0.7",
        "tenacity==8.0.1",
        "toposort==1.6",
        "dataclasses==0.8;python_version<'3.7'"
    ],
    extras_require={
        "all": ["requests==2.25.1", "tenacity==6.2.0"]
    },
    python_requires='>=3.6',
)