base_client.py 8.59 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
# Copyright © 2020 Amazon Web Services
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14
import importlib
15
import logging
16
import requests
17
from configparser import SafeConfigParser
18

19
from osdu_api.auth.authorization import authorize, TokenRefresher
20
21
from osdu_api.configuration.base_config_manager import BaseConfigManager
from osdu_api.configuration.config_manager import DefaultConfigManager
22
from osdu_api.exceptions.exceptions import MakeRequestError
23
24
from osdu_api.model.http_method import HttpMethod

25

26
class BaseClient:
    """
    Base client that is meant to be extended by service specific clients.

    It reads service URLs and the data partition from a config manager and sends
    HTTP requests authorized by one of three strategies (see ``make_request``):
    an explicit bearer token, a ``TokenRefresher``, or a provider-specific
    service principal token.
    """

    def __init__(
        self,
        config_manager: BaseConfigManager = None,
        data_partition_id=None,
        token_refresher: TokenRefresher = None,
        logger=None
    ):
        """
        Initialize the client from configuration and, when enabled, fetch a
        service principal token using provider-specific logic.

        :param config_manager: source of configuration values; a
            ``DefaultConfigManager`` is created when omitted
        :type config_manager: BaseConfigManager, optional
        :param data_partition_id: value sent in the ``data-partition-id`` header;
            read from the ``environment`` config section when omitted
        :param token_refresher: used by ``make_request`` when no explicit bearer
            token is given
        :type token_refresher: TokenRefresher, optional
        :param logger: logger to use; defaults to this module's logger
        """
        self._parse_config(config_manager, data_partition_id)
        # Guards the single re-authentication retry in _send_request_with_principle_token.
        self.unauth_retries = 0
        if self.use_service_principal:
            self._refresh_service_principal_token()

        self.logger = logger if logger is not None else logging.getLogger(__name__)
        self.token_refresher = token_refresher

    def _parse_config(self, config_manager: BaseConfigManager = None, data_partition_id=None):
        """
        Read provider, service URLs and auth settings from the config manager.

        :param config_manager: ConfigManager to get configs, defaults to None
            (a ``DefaultConfigManager`` is then used)
        :type config_manager: BaseConfigManager, optional
        :param data_partition_id: explicit data partition id; when None it is
            read from ``environment.data_partition_id``
        """
        self.config_manager = config_manager or DefaultConfigManager()
        cfg = self.config_manager

        self.provider = cfg.get('provider', 'name')

        self.data_workflow_url = cfg.get('environment', 'data_workflow_url')
        self.dataset_url = cfg.get('environment', 'dataset_url')
        self.entitlements_url = cfg.get('environment', 'entitlements_url')
        self.file_dms_url = cfg.get('environment', 'file_dms_url')
        self.legal_url = cfg.get('environment', 'legal_url')
        self.schema_url = cfg.get('environment', 'schema_url')
        self.search_url = cfg.get('environment', 'search_url')
        self.storage_url = cfg.get('environment', 'storage_url')
        self.partition_url = cfg.get('environment', 'partition_url')
        self.ingestion_workflow_url = cfg.get('environment', 'ingestion_workflow_url')

        self.use_service_principal = cfg.getbool('environment', 'use_service_principal', False)
        if self.use_service_principal:
            # Only needed when tokens are obtained through the provider module.
            self.service_principal_module_name = cfg.get('provider', 'service_principal_module_name')

        if data_partition_id is None:
            self.data_partition_id = cfg.get('environment', 'data_partition_id')
        else:
            self.data_partition_id = data_partition_id

    def _refresh_service_principal_token(self):
        """
        Fetch a fresh service principal token and cache it on the instance.

        The token logic is injected dynamically: the module
        ``osdu_api.providers.<provider>.<service_principal_module_name>`` is
        imported and its ``get_service_principal_token()`` is called.
        """
        module_path = f'osdu_api.providers.{self.provider}.{self.service_principal_module_name}'
        entitlements_client = importlib.import_module(module_path)
        self.service_principal_token = entitlements_client.get_service_principal_token()

    @staticmethod
    def _format_bearer_token(bearer_token):
        """
        Return ``bearer_token`` with a ``'Bearer '`` prefix, adding it only if missing.

        ``None`` is returned unchanged.
        """
        if bearer_token is not None and not bearer_token.startswith('Bearer '):
            return 'Bearer ' + bearer_token
        return bearer_token

    @staticmethod
    def _send_request(method: HttpMethod, url: str, data: str, headers: dict, params: dict) -> requests.Response:
        """
        Send request to OSDU

        :param method: HTTP method
        :type method: HttpMethod
        :param url: service's URL
        :type url: str
        :param data: request's data (sent only for POST and PUT)
        :type data: str
        :param headers: request's headers
        :type headers: dict
        :param params: params
        :type params: dict
        :return: response from OSDU service
        :rtype: requests.Response
        :raises MakeRequestError: if ``method`` is not GET, DELETE, POST or PUT
        """
        # NOTE(review): TLS certificate verification is disabled (verify=False)
        # for every request; confirm this is intended for all deployments.
        if method == HttpMethod.GET:
            return requests.get(url=url, params=params, headers=headers, verify=False)
        if method == HttpMethod.DELETE:
            return requests.delete(url=url, params=params, headers=headers, verify=False)
        if method == HttpMethod.POST:
            return requests.post(url=url, params=params, data=data, headers=headers, verify=False)
        if method == HttpMethod.PUT:
            return requests.put(url=url, params=params, data=data, headers=headers, verify=False)
        # Previously fell through to an UnboundLocalError; fail with a clear message instead.
        raise MakeRequestError("Unsupported HTTP method: %s" % (method,))

    def _send_request_with_principle_token(
        self,
        method: HttpMethod,
        url: str,
        data: str,
        headers: dict,
        params: dict,
    ) -> requests.Response:
        """
        Send a request authorized with the cached service principal token.

        On a 401 or 403 response, the token is refreshed once and the request is
        retried. The response of the retry is returned.

        :param headers: request headers; its ``Authorization`` entry is overwritten
        :return: response from OSDU service
        :rtype: requests.Response
        """
        headers["Authorization"] = self._format_bearer_token(self.service_principal_token)

        response = self._send_request(method, url, data, headers, params)

        if response.status_code in (401, 403) and self.unauth_retries < 1:
            # The cached token may have expired: refresh it and retry exactly once.
            # The nested call sees unauth_retries == 1, so it will not retry again.
            self.unauth_retries += 1
            try:
                self._refresh_service_principal_token()
                response = self._send_request_with_principle_token(method, url, data, headers, params)
            finally:
                # Reset even if refreshing fails, so later requests can still retry.
                self.unauth_retries = 0

        self.unauth_retries = 0
        return response

    def _send_request_with_bearer_token(
        self,
        method: HttpMethod,
        url: str,
        data: str,
        headers: dict,
        params: dict,
        bearer_token: str
    ) -> requests.Response:
        """
        Send request with bearer_token provided by SDK user.

        :param method: HTTP method
        :type method: HttpMethod
        :param url: service's URL
        :type url: str
        :param data: request's data
        :type data: str
        :param headers: request's headers; its ``Authorization`` entry is overwritten
        :type headers: dict
        :param params: params
        :type params: dict
        :param bearer_token: bearer_token, with or without the ``'Bearer '`` prefix
        :type bearer_token: str
        :return: response from OSDU service
        :rtype: requests.Response
        :raises requests.HTTPError: if the response status is 4xx or 5xx
        """
        headers["Authorization"] = self._format_bearer_token(bearer_token)

        response = self._send_request(method, url, data, headers, params)
        # raise_for_status() is a no-op for successful responses.
        response.raise_for_status()
        return response

    @authorize()
    def _send_request_with_token_refresher(
        self,
        headers: dict,
        method: HttpMethod,
        url: str,
        data: str,
        params: dict
    ) -> requests.Response:
        """
        Send a request through the ``@authorize`` decorator.

        The argument order (``headers`` first) differs from the other senders
        because the decorator requires it. The decorator is presumably
        responsible for adding the Authorization header from
        ``self.token_refresher``; confirm in ``osdu_api.auth.authorization``.
        """
        return self._send_request(method, url, data, headers, params)

    def make_request(
        self,
        method: HttpMethod,
        url: str,
        data='',
        add_headers: dict = None,
        params: dict = None,
        bearer_token=None
    ) -> requests.Response:
        """
        Make a request using the ``requests`` library, taking additional headers if necessary.

        The authorization strategy is chosen in this order:

        1. an explicit ``bearer_token``;
        2. ``self.token_refresher``;
        3. the service principal token, when ``use_service_principal`` is configured.

        :param method: HTTP method
        :type method: HttpMethod
        :param url: service's URL
        :type url: str
        :param data: request body (used for POST and PUT)
        :param add_headers: extra headers, applied over the default
            ``content-type`` and ``data-partition-id`` headers
        :type add_headers: dict, optional
        :param params: query parameters
        :type params: dict, optional
        :param bearer_token: token to use instead of the configured strategies
        :return: response from OSDU service
        :rtype: requests.Response
        :raises MakeRequestError: if no strategy to obtain a bearer token is available
        """
        params = params or {}

        headers = {
            'content-type': 'application/json',
            'data-partition-id': self.data_partition_id,
        }
        # Caller-supplied headers override the defaults above.
        headers.update(add_headers or {})

        if bearer_token:
            response = self._send_request_with_bearer_token(method, url, data, headers, params, bearer_token)
        elif self.token_refresher:
            # _send_request_with_token_refresher has other method signature to work with @authorize decorator
            response = self._send_request_with_token_refresher(headers, method, url, data, params)
        elif self.use_service_principal:
            response = self._send_request_with_principle_token(method, url, data, headers, params)
        else:
            raise MakeRequestError("There is no strategy to get Bearer token.")
        return response