Commit 3ac10deb authored by Siarhei Khaletski (EPAM)'s avatar Siarhei Khaletski (EPAM) 🚩
Browse files

Added support of ADC ("Application Default Credentials") for GCP

parent 959891b7
Pipeline #68402 failed with stages
in 1 minute and 30 seconds
......@@ -2,12 +2,19 @@ import aiohttp
import json
from typing import List, Optional
import datetime
import enum
import os
import time
import jwt
from urllib.parse import quote_plus
from urllib.parse import urlencode
from osdu.core.auth import AuthBase
class Type(enum.Enum):
AUTHORIZED_USER = 'authorized_user'
SERVICE_ACCOUNT = 'service_account'
GCE_METADATA = 'gce_metadata'
class GoogleAccountAuth(AuthBase):
_scheme: str = 'Bearer'
......@@ -24,9 +31,9 @@ class GoogleAccountAuth(AuthBase):
:param scopes: scopes
"""
super().__init__()
with open(service_file, 'r') as f:
self.service_data = json.load(f)
self.service_file = service_file
self.service_data = self.get_service_data()
self.session = session
self.scopes = ' '.join(scopes or [])
self.access_token: Optional[str] = None
......@@ -38,6 +45,36 @@ class GoogleAccountAuth(AuthBase):
""" Return scheme """
return self._scheme
@property
def token_type(self):
if self.service_data:
return Type(self.service_data['type'])
else:
return Type.GCE_METADATA
@property
def token_uri(self):
if self.service_data:
return self.service_data.get('token_uri', 'https://oauth2.googleapis.com/token')
else:
return (f'http://metadata.google.internal/computeMetadata/v1/instance/service-accounts'
'/default/token?recursive=true')
def get_service_data(self) -> Optional[dict]:
service_file = self.service_file or os.environ.get('GOOGLE_APPLICATION_CREDENTIALS')
if not service_file:
cloudsdk_config = os.environ.get('CLOUDSDK_CONFIG')
sdkpath = (cloudsdk_config
or os.path.join(os.path.expanduser('~'), '.config', 'gcloud'))
service_file = os.path.join(sdkpath, 'application_default_credentials.json')
try:
with open(service_file, 'r') as f:
return json.load(f)
except Exception:
return {} # e.g. for GCE
async def header_value(self) -> str:
token = await self._get_token()
return f'{self._scheme} {token}'
......@@ -56,11 +93,9 @@ class GoogleAccountAuth(AuthBase):
return self.service_data['client_email']
async def _refresh_token_for_service_account(self):
token_uri = self.service_data.get('token_uri', 'https://oauth2.googleapis.com/token')
now = int(time.time())
assertion_payload = {
'aud': token_uri,
'aud': self.token_uri,
'exp': now + 3600,
'iat': now,
'iss': self.service_data['client_email'],
......@@ -78,18 +113,32 @@ class GoogleAccountAuth(AuthBase):
refresh_headers = {'Content-Type': 'application/x-www-form-urlencoded'}
async with self.session.post(token_uri, data=payload, headers=refresh_headers) as resp:
async with self.session.post(self.token_uri, data=payload, headers=refresh_headers) as resp:
if resp.status != 200:
raise Exception() #TODO
content = await resp.json()
# print('from token ------------')
# print(content)
async def _refresh_token_for_authorized_user(self):
payload = urlencode({
'grant_type': 'refresh_token',
'client_id': self.service_data['client_id'],
'client_secret': self.service_data['client_secret'],
'refresh_token': self.service_data['refresh_token'],
})
refresh_headers = {'Content-Type': 'application/x-www-form-urlencoded'}
async with self.session.post(url=self.token_uri, data=payload,
headers=refresh_headers) as resp:
if resp.status != 200:
raise Exception() #TODO
self.access_token = str(content['access_token'])
self.access_token_duration = int(content['expires_in'])
self.access_token_acquired_at = datetime.datetime.utcnow()
async def _refresh_token_for_gce_metadata(self):
refresh_headers = {'metadata-flavor': 'Google'}
async with self.session.get(url=self.token_uri, headers=refresh_headers) as resp:
if resp.status != 200:
raise Exception() #TODO
async def _get_token(self) -> str:
if self.access_token:
now = datetime.datetime.utcnow()
......@@ -97,5 +146,17 @@ class GoogleAccountAuth(AuthBase):
if delta < 3000:
return self.access_token
await self._refresh_token_for_service_account()
if self.token_type == Type.AUTHORIZED_USER:
resp = await self._refresh_token_for_authorized_user()
elif self.token_type == Type.SERVICE_ACCOUNT:
resp = await self._refresh_token_for_service_account()
elif self.token_type == Type.GCE_METADATA:
resp = await self._refresh_token_for_gce_metadata()
content = await resp.json()
self.access_token = str(content['access_token'])
self.access_token_duration = int(content['expires_in'])
self.access_token_acquired_at = datetime.datetime.utcnow()
return self.access_token
......@@ -72,7 +72,6 @@ class GCloudAioStorage(BlobStorageBase):
if forwarded_auth is not None:
return await forwarded_auth.token()
assert self._service_account_file, 'No credentials provided'
token_cache = self._access_token_dict
cache_key = f'{project}_{bucket}'
tenant_access_token = token_cache.get(cache_key, None)
......
......@@ -33,6 +33,14 @@ async def storage_client(request):
yield GCloudAioStorage(session=session, service_account_file=_TESTING_CFG.credentials)
await session.close()
@pytest.fixture(params=['GCloudAioStorage'])
async def storage_client_adc(request):
client_name = request.param
if client_name == 'GCloudAioStorage':
session = aiohttp.ClientSession()
GCloudAioStorage._access_token_dict = {} # clear cache
yield GCloudAioStorage(session=session)
await session.close()
@pytest.fixture
async def test_tenant():
......@@ -56,6 +64,13 @@ async def test_list_objects(storage_client, test_tenant):
assert file_name in result
@pytest.mark.asyncio
async def test_list_objects_adc(storage_client_adc, test_tenant):
result = await storage_client_adc.list_objects(test_tenant)
for file_name, _ in TEST_DATA['initial_files']:
assert file_name in result
@pytest.mark.asyncio
async def test_download(storage_client, test_tenant):
name, expected_content = TEST_DATA['initial_files'][0]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment