Compare revisions

Showing 1096 additions and 111 deletions
# Copyright 2021 Schlumberger
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dask.distributed import WorkerPlugin
from app.helper.logger import get_logger, init_logger
from app.conf import Config
class DaskWorkerPlugin(WorkerPlugin):
def __init__(self, logger=None, register_fsspec_implementation=None) -> None:
self.worker = None
self._register_fsspec_implementation = register_fsspec_implementation
super().__init__()
if logger: logger.debug("WorkerPlugin initialised")
def setup(self, worker):
init_logger(service_name=Config.service_name.value, config=Config)
self.worker = worker
if self._register_fsspec_implementation:
self._register_fsspec_implementation()
def teardown(self, worker):
get_logger().debug(f"Worker '{worker.name}' with id '{worker.id}' is closing - {worker}")
def transition(self, key, start, finish, *args, **kwargs):
if finish == 'error':
exc = self.worker.exceptions.get(key, None)
get_logger().exception(f"Task '{key}' has failed with exception: {exc}")
from typing import List
import json
import fsspec
import pandas as pd
from .utils import WDMS_INDEX_NAME
from app.model.model_chunking import DataframeBasicDescribe
# imports from bulk_persistence
from ..json_orient import JSONOrient
from ..mime_types import MimeType
from ..dataframe_serializer import DataframeSerializerSync
from ..dataframe_validators import (DataFrameValidationFunc, assert_df_validate, validate_index,
columns_not_in_reserved_names, validate_number_of_columns)
from .errors import BulkNotProcessable, BulkSaveException
from . import storage_path_builder as path_builder
from . import session_file_meta as session_meta
from ..consistency_checks import DataConsistencyChecks
"""
Contains functions related to writing bulk data that are meant to be run inside a Dask worker
"""
def basic_describe(df: pd.DataFrame) -> DataframeBasicDescribe:
full_cols = df.columns.tolist()
if len(full_cols) > 20: # truncate if too many columns, show 10 first and 10 last
cols = [*full_cols[0:10], '...', *full_cols[-10:]]
else:
cols = full_cols
index_exists = len(df.index)
return DataframeBasicDescribe(rowCount=len(df.index),
columnCount=len(full_cols),
columns=cols,
indexStart=str(df.index[0]) if index_exists else "0",
indexEnd=str(df.index[-1]) if index_exists else "0",
indexType=str(df.index.dtype) if index_exists else "")
def write_bulk_without_session(data_handle,
data_getter,
content_type: MimeType,
df_validator_func: DataFrameValidationFunc,
bulk_base_path: str,
storage_options,
consistency_checks: DataConsistencyChecks,
record: "Record",
) -> DataframeBasicDescribe:
"""
process post data outside of a session - write data straight to blob storage
:param data_handle: dataframe as input ipc raw bytes wrapped (file-like obj)
:param data_getter: function to get data from the handle
:param content_type: content type value as mime type (supports json and parquet)
:param df_validator_func: optional validation callable.
:param bulk_base_path: base path of the final object on blob storage.
:param storage_options: storage options
:param consistency_checks: optional consistency checks object
:param record: the entity to which the bulk belongs
:return: basic describe of the dataframe
:throw: BulkNotProcessable, BulkSaveException
"""
# 1- deserialize to pandas dataframe
try:
with data_getter(data_handle) as file_like_data:
df = DataframeSerializerSync.load(file_like_data, content_type, JSONOrient.split)
except Exception as e:
raise BulkNotProcessable(f'parsing error: {e}') from e
data_handle = None # unref
# 2- input dataframe validation
assert_df_validate(df, [df_validator_func, validate_number_of_columns, columns_not_in_reserved_names, validate_index])
# 2(bis)- checks data consistency
consistency_checks.check_bulk_consistency_on_post_bulk(record, df)
# set the name of the index column
df.index.name = WDMS_INDEX_NAME
# 3- build blob filename and final full blob path
filename = session_meta.generate_chunk_filename(df)
full_file_path = path_builder.join(bulk_base_path, filename + '.parquet')
# 4- save/upload the dataframe
try:
DataframeSerializerSync.to_parquet(df, full_file_path, storage_options=storage_options)
except Exception as e:
raise BulkSaveException('Unexpected error while saving bulk') from e
# 5- return basic describe
return basic_describe(df)
def add_chunk_in_session(data_handle,
data_getter,
content_type: MimeType,
df_validator_func: DataFrameValidationFunc,
record_session_path: str,
storage_options) -> DataframeBasicDescribe:
"""
process add-chunk data inside a session
:param data_handle: input ipc raw bytes wrapped (file-like obj)
:param data_getter: function to get data from the handle
:param content_type: content type as mime type (supports json and parquet)
:param df_validator_func: optional validation callable.
:param record_session_path: base path of the session associated with the record.
:param storage_options: storage options
:return: basic describe of the dataframe
:throw: BulkNotProcessable, BulkSaveException
"""
# 1- deserialize
try:
with data_getter(data_handle) as file_like_data:
df = DataframeSerializerSync.load(file_like_data, content_type, JSONOrient.split)
except Exception as e:
raise BulkNotProcessable(f'parsing error: {e}') from e
data_handle = None # unref
# 2- perform some checks
assert_df_validate(df, [df_validator_func, validate_number_of_columns, columns_not_in_reserved_names, validate_index])
# sort columns by name and set the index column name  # TODO: could this be avoided? Then we could keep the input untouched and save a serialization step
df = df[sorted(df.columns)]
df.index.name = WDMS_INDEX_NAME
# 3- build blob filename and final full blob path
filename = session_meta.generate_chunk_filename(df)
# 4- build and push chunk meta file
meta_file_path, protocol = path_builder.remove_protocol(f'{record_session_path}/{filename}.meta')
fs = fsspec.filesystem(protocol, **(storage_options if storage_options else {}))
with fs.open(meta_file_path, 'w') as outfile:
json.dump(session_meta.build_chunk_metadata(df), outfile)
# 5- save/upload the dataframe
parquet_file_path = f'{record_session_path}/{filename}.parquet'
try:
DataframeSerializerSync.to_parquet(df, parquet_file_path, storage_options=storage_options)
except Exception as e:
raise BulkSaveException('Unexpected error while saving bulk') from e
# 6- return basic describe
return basic_describe(df)
......@@ -13,24 +13,113 @@
# limitations under the License.
from fastapi import status, HTTPException
from dask.distributed import scheduler
from pyarrow.lib import ArrowException, ArrowInvalid
from functools import wraps
from app.conf import Config
from app.helper.logger import get_logger
class BulkError(Exception):
class BulkError(RuntimeError):
http_status: int
def raise_as_http(self):
raise HTTPException(status_code=self.http_status, detail=str(self))
raise HTTPException(status_code=self.http_status, detail=str(self)) from self
class BulkRecordNotFound(BulkError):
http_status = status.HTTP_404_NOT_FOUND
def __init__(self, record_id=None, bulk_id=None, message=None):
ex_message = 'bulk '
if bulk_id:
ex_message += f'{bulk_id} '
if record_id:
ex_message += f'for record {record_id} '
ex_message += 'not found'
if message:
ex_message += ': ' + message
super().__init__(ex_message)
class BulkNotFound(BulkError):
class BulkCurvesNotFound (BulkError):
http_status = status.HTTP_404_NOT_FOUND
def __init__(self, record_id, bulk_id):
self.message = f'bulk {bulk_id} for record {record_id} not found'
def __init__(self, curves=None, message=None):
ex_message = 'bulk '
if curves:
ex_message += f'for curves: {curves} not found'
if message:
ex_message += ': ' + message
super().__init__(ex_message)
class BulkNotProcessable(BulkError):
http_status = status.HTTP_422_UNPROCESSABLE_ENTITY
def __init__(self, bulk_id):
self.message = f'bulk {bulk_id} not processable'
def __init__(self, bulk_id=None, message=None):
ex_message = 'bulk '
if bulk_id:
ex_message += f'{bulk_id} '
ex_message += 'not processable'
if message:
ex_message += ': ' + message
super().__init__(ex_message)
class BulkSaveException(BulkError):
http_status = status.HTTP_500_INTERNAL_SERVER_ERROR
class InternalBulkError(BulkError):
http_status = status.HTTP_500_INTERNAL_SERVER_ERROR
def __init__(self, message=None):
ex_message = 'Internal bulk error'
if message:
ex_message += ': ' + message
super().__init__(ex_message)
class FilterError(BulkError):
http_status = status.HTTP_400_BAD_REQUEST
def __init__(self, reason):
ex_message = f'filter error: {reason}'
super().__init__(ex_message)
class TooManyColumnsRequested(BulkError):
http_status = status.HTTP_400_BAD_REQUEST
def __init__(self, nb_requested_cols):
ex_message = (
f"Too many columns: requested '{nb_requested_cols}',"
f" maximum allowed '{Config.max_columns_return.value}'")
super().__init__(ex_message)
def internal_bulk_exceptions(target):
"""
Decorator to handle exceptions that should not be exposed to the outside world, e.g. PyArrow or Dask exceptions
"""
@wraps(target)
async def async_inner(*args, **kwargs):
try:
return await target(*args, **kwargs)
except ArrowInvalid as e:
get_logger().exception(f"Pyarrow ArrowInvalid when running {target.__name__}")
raise BulkNotProcessable(f"Unable to process bulk - {str(e)}")
except ArrowException:
get_logger().exception(f"Pyarrow exception raised when running {target.__name__}")
raise BulkNotProcessable("Unable to process bulk - Arrow")
except scheduler.KilledWorker:
get_logger().exception(f"Dask worker has been killed when running '{target.__name__}'")
raise InternalBulkError("Out of memory")
except Exception:
get_logger().exception(f"Unexpected exception raised when running '{target.__name__}'")
raise
return async_inner
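# Usage sketch (assumption): the decorator wraps async bulk entry points so PyArrow/Dask
# internals are surfaced as BulkError subclasses instead of leaking library exceptions, e.g.:
#
#   @internal_bulk_exceptions
#   async def get_data(ctx, record_id):
#       ...  # any ArrowInvalid raised here becomes BulkNotProcessable
#
# 'get_data' is a hypothetical handler; the decorator only supports coroutine functions
# since it awaits the wrapped target.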
from logging import Logger
from dask.utils import format_bytes, parse_bytes
from distributed import system
from distributed.deploy.utils import nprocesses_nthreads
from app.conf import Config
class DaskException(Exception):
pass
# Amount of memory reserved for the FastAPI server + ProcessPoolExecutors
memory_leeway = parse_bytes("600Mi")
def min_worker_memory_recommended(config: Config):
"""Minimal amount of memory required for a Dask worker to not get bad performances"""
return parse_bytes(config.min_worker_memory.value)
def system_memory():
"""returns the detected memory limit for this system (done by distributed)"""
return system.MEMORY_LIMIT
def available_memory_for_workers():
"""Return amount of RAM available for Dask's workers after withdrawing RAM required by server itself"""
return max(0, (system_memory() - memory_leeway))
def recommended_workers_and_threads():
""" Return the recommended numbers of worker and threads according the cpus available provided by Dask """
return nprocesses_nthreads()
def get_dask_configuration(*, config: Config, logger: Logger):
"""
Return recommended Dask workers configuration
"""
n_workers, threads_per_worker = recommended_workers_and_threads()
available_memory_bytes = available_memory_for_workers()
worker_memory_limit = int(available_memory_bytes / n_workers)
logger.info(
f"Dask client - system.MEMORY_LIMIT: {format_bytes(system_memory())} "
f"- available_memory_bytes: {format_bytes(available_memory_bytes)} "
f"- min_worker_memory_recommended: {format_bytes(min_worker_memory_recommended(config))} "
f"- computed worker_memory_limit: {format_bytes(worker_memory_limit)} for {n_workers} workers"
)
if min_worker_memory_recommended(config) > worker_memory_limit:
n_workers = available_memory_bytes // min_worker_memory_recommended(config)
if not n_workers >= 1:
min_memory = min_worker_memory_recommended(config) + memory_leeway
message = (
f"Not enough memory available to start Dask worker. "
f"Please, consider upgrading container memory to {format_bytes(min_memory)}"
)
logger.error(
f"Dask client - {message} - "
f"n_workers: {n_workers} threads_per_worker: {threads_per_worker}, "
f"available_memory_bytes: {available_memory_bytes} "
)
raise DaskException(message)
worker_memory_limit = available_memory_bytes / n_workers
logger.warning(
f"Dask client - available RAM is too low. Reducing number of workers "
f"to {n_workers} running with {format_bytes(worker_memory_limit)} of RAM"
)
return n_workers, threads_per_worker, worker_memory_limit
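# Usage sketch (assumption): the returned tuple typically feeds a local Dask cluster, e.g.:
#
#   from dask.distributed import Client, LocalCluster
#   n_workers, threads_per_worker, memory_limit = get_dask_configuration(config=Config, logger=logger)
#   cluster = LocalCluster(n_workers=n_workers,
#                          threads_per_worker=threads_per_worker,
#                          memory_limit=memory_limit)
#   client = Client(cluster)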
# Copyright 2021 Schlumberger
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import hashlib
import json
import os
import time
from contextlib import suppress
from operator import attrgetter
from typing import Dict, Generator, List
from distributed.worker import get_client
import pandas as pd
from .utils import share_items
from ...helper.logger import get_logger
from ...helper.traces import with_trace
from app.persistence.sessions_storage import Session
from ..capture_timings import capture_timings
from .storage_path_builder import add_protocol, record_session_path
class SessionFileMeta:
"""The class extract information about chunks."""
def __init__(self, fs, protocol: str, file_path: str, lazy: bool = True) -> None:
"""
Args:
fs: fsspec filesystem
protocol (str): storage protocol (e.g. 's3')
file_path (str): the parquet chunk file path
lazy (bool, optional): prefetch the metadata file if False, else read on demand. Defaults to True.
"""
self._fs = fs
file_name = os.path.basename(file_path)
start, end, tail = file_name.split('_')
self.start = float(start) # data time support ?
self.end = float(end)
self.time, self.shape, tail = tail.split('.')
self._meta = None
self.path = file_path
self.protocol = protocol
if not lazy:
self._read_meta()
def _read_meta(self):
if not self._meta:
path, _ = os.path.splitext(self.path)
with self._fs.open(path + '.meta') as meta_file:
self._meta = json.load(meta_file)
return self._meta
@property
def columns(self) -> List[str]:
"""Returns the column names"""
return self._read_meta()['columns']
@property
def dtypes(self) -> List[str]:
"""Returns the column dtypes"""
return self._read_meta()['dtypes']
@property
def nb_rows(self) -> int:
"""Returns the number of rows of the chunk"""
return self._read_meta()['nb_rows']
@property
def path_with_protocol(self) -> str:
"""Returns chunk path with protocol"""
return add_protocol(self.path, self.protocol)
@property
def index_hash(self) -> str:
"""Returns the index hash"""
return self._read_meta()['index_hash']
def overlap(self, other: 'SessionFileMeta') -> bool:
"""Returns True if indexes overlap."""
return self.end >= other.start and other.end >= self.start
def has_common_columns(self, other: 'SessionFileMeta') -> bool:
"""Returns True if contains common columns with others."""
return share_items(self.columns, other.columns)
def generate_chunk_filename(dataframe: pd.DataFrame) -> str:
"""Generate a chunk filename composed of information from the given dataframe
{first_index}_{last_index}_{time}.{shape}
The shape is a hash of the column names + column dtypes
If chunks have the same shape, Dask can read them together.
Warnings:
- This function is not idempotent!
- Do not modify the name without updating the class SessionFileMeta!
Indeed, SessionFileMeta parses information from the chunk filename
- Filenames impact partition order in Dask, as it orders them by 'natural key'.
That's why the start index is in the first position
Raises:
IndexError - if empty dataframe
>>> generate_chunk_filename(pd.DataFrame({'A': range(10), 'B': range(10)}, index=range(10)))
'0_9_1637223437910.526782c41fe12c3249046fedcc45563ef3662250'
>>> generate_chunk_filename(pd.DataFrame({'A': range(10), 'B': range(10)}, index=range(10,20)))
'10_19_1637223490719.526782c41fe12c3249046fedcc45563ef3662250'
>>> generate_chunk_filename(pd.DataFrame({'A': [1], 'B': [1]}, index=[datetime.datetime.now()]))
'1639672097644401000_1639672097644401000_1639668497645.526782c41fe12c3249046fedcc45563ef3662250'
>>> generate_chunk_filename(pd.DataFrame({'A': []}, index=[]))
IndexError: index 0 is out of bounds for axis 0 with size 0
"""
first_idx, last_idx = dataframe.index[0], dataframe.index[-1]
if isinstance(dataframe.index, pd.DatetimeIndex):
first_idx, last_idx = dataframe.index[0].value, dataframe.index[-1].value
shape_str = '_'.join(f'{cn}:{dt}' for cn, dt in dataframe.dtypes.items())
shape = hashlib.sha1(shape_str.encode()).hexdigest()
cur_time = round(time.time() * 1000)
return f'{first_idx}_{last_idx}_{cur_time}.{shape}'
def build_chunk_metadata(dataframe: pd.DataFrame) -> dict:
"""Returns dataframe metadata
Other metadata such as start_index or stop_index are saved into the chunk filename
>>> build_chunk_metadata(pd.DataFrame({'A': [1,2,3], 'B': [4,5,6]}, index=[0,1,2]))
{'columns': ['A', 'B'], 'dtypes': ['int64', 'int64'], 'nb_rows': 3, 'index_hash': 'ab2fa50ae23ce035bad2e77ec5e0be05c2f4b816'}
"""
return {
"columns": list(dataframe.columns),
"dtypes": [str(dt) for dt in dataframe.dtypes],
"nb_rows": len(dataframe.index),
"index_hash": hashlib.sha1(dataframe.index.values).hexdigest()
}
@capture_timings('get_chunks_metadata')
@with_trace('get_chunks_metadata')
async def get_chunks_metadata(filesystem, protocol: str, base_directory: str, session: Session) -> List[SessionFileMeta]:
"""Return metadata objects for a given session"""
session_path = record_session_path(base_directory, session.id, session.recordId)
with suppress(FileNotFoundError):
parquet_files = [f for f in filesystem.ls(session_path) if f.endswith(".parquet")]
futures = get_client().map(lambda f: SessionFileMeta(filesystem, protocol, f, lazy=False) , parquet_files)
return await get_client().gather(futures)
return []
def get_next_chunk_files(
chunks_info
) -> Generator[List[SessionFileMeta], None, None]:
"""Generator which groups session chunk files in lists of files that can be read directly with dask
File can be grouped if they have the same schemas and no overlap between indexes
"""
chunks_info.sort(key=attrgetter('time'))
cache: Dict[str, List[SessionFileMeta]] = {}
columns_in_cache = set() # keep track of columns present in the cache
for chunk in chunks_info:
if chunk.shape in cache: # if other chunks have the same shape
# looking for an overlapping chunk
for i, cached_chunk in enumerate(cache[chunk.shape]):
if chunk.overlap(cached_chunk):
if chunk.index_hash == cached_chunk.index_hash:
# if chunks are identical in shape and index just keep the last one
get_logger().info(f"Duplicated chunk skipped : '{chunk.path}'")
cache[chunk.shape].pop(i)
else:
yield cache[chunk.shape]
del cache[chunk.shape]
break
elif not columns_in_cache.isdisjoint(chunk.columns): # else if columns conflict
conflicting_chunk = next(metas[0] for metas in cache.values()
if chunk.has_common_columns(metas[0]))
yield cache[conflicting_chunk.shape]
columns_in_cache = columns_in_cache.difference(conflicting_chunk.columns)
del cache[conflicting_chunk.shape]
cache.setdefault(chunk.shape, []).append(chunk)
columns_in_cache.update(chunk.columns)
yield from cache.values()
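# Usage sketch (assumption): each yielded group shares the same schema and has no index
# overlap, so it can be read as a single Dask dataframe, e.g.:
#
#   import dask.dataframe as dd
#   for group in get_next_chunk_files(chunks_info):
#       ddf = dd.read_parquet([meta.path_with_protocol for meta in group],
#                             storage_options=storage_options)
#
# 'storage_options' is a hypothetical dict of filesystem credentials, as used elsewhere in this change.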
# Copyright 2021 Schlumberger
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Utility module gathering functions to build paths for bulk storage
"""
import hashlib
from os.path import join, relpath
from typing import Optional, Tuple
def hash_record_id(record_id: str) -> str:
"""encode the record_id to be a valid path name"""
return hashlib.sha1(record_id.encode()).hexdigest()
def build_base_path(base_directory: str, protocol: Optional[str] = None) -> str:
"""return the base directory, add the protocol if requested"""
return f'{protocol}://{base_directory}' if protocol else base_directory
def add_protocol(path: str, protocol: str) -> str:
"""add protocole to the path"""
prefix = protocol + '://'
if not path.startswith(prefix):
return prefix + path
return path
def remove_protocol(path: str) -> Tuple[str, str]:
"""remove protocol for path if any, return tuple[path, protocol].
If no protocol in path then protocol=''
>>> remove_protocol('s3://path/to/my/file')
('path/to/my/file', 's3')
>>> remove_protocol('path/to/my/file')
('path/to/my/file', '')
"""
if '://' not in path:
return path, ''
sep_idx = path.index('://')
return path[sep_idx + 3:], path[:sep_idx]
def record_path(
base_directory: str, record_id, protocol: Optional[str] = None
) -> str:
"""Return the entity path.
(path where all data relateed to an entity are saved"""
encoded_id = hash_record_id(record_id)
base_path = build_base_path(base_directory, protocol)
return join(base_path, encoded_id)
def record_bulk_path(
base_directory: str, record_id: str, bulk_id: str, protocol: Optional[str] = None
) -> str:
"""Return the path corresponding to the specified bulk."""
entity_path = record_path(base_directory, record_id, protocol)
return join(entity_path, 'bulk', bulk_id, 'data')
def record_session_path(
base_directory: str, session_id: str, record_id: str, protocol: Optional[str] = None
) -> str:
"""Return the path corresponding to the specified session."""
entity_path = record_path(base_directory, record_id, protocol)
return join(entity_path, 'session', session_id, 'data')
def record_relative_path(base_directory: str, record_id: str, path: str) -> str:
"""Returns the path relative to the specified record."""
base_path = record_path(base_directory, record_id)
path, _proto = remove_protocol(path)
return relpath(path, base_path)
def full_path(
base_directory: str, record_id: str, rel_path: str, protocol: Optional[str] = None
) -> str:
"""Returns the full path of a record from a relative path"""
return join(record_path(base_directory, record_id, protocol), rel_path)
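# Illustration with hypothetical identifiers: composing a bulk path then recovering the
# record-relative part (the sha1-encoded record id is shown as <hash>):
#
#   >>> p = record_bulk_path('tenant-bucket', 'opendes:wellbore:123', 'bulk-0001', protocol='s3')
#   # 's3://tenant-bucket/<hash>/bulk/bulk-0001/data'
#   >>> record_relative_path('tenant-bucket', 'opendes:wellbore:123', p)
#   'bulk/bulk-0001/data'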
from typing import Callable, Union
from enum import Enum
from dask.distributed import Client
import pandas as pd
from dask.utils import funcname
from dask.base import tokenize
from opencensus.trace.span import SpanKind
from opencensus.trace import tracer as open_tracer
from opencensus.trace.samplers import AlwaysOnSampler
from app.helper.traces import create_exporter
from app.conf import Config
from app.helper import traces
from app.context import get_ctx
from opencensus.trace import execution_context
from . import dask_worker_write_bulk as bulk_writer
_EXPORTER = None
......@@ -11,18 +22,113 @@ _EXPORTER = None
def wrap_trace_process(*args, **kwargs):
global _EXPORTER
tracing_headers = kwargs.pop('tracing_headers')
target_func = kwargs.pop('target_func')
span_context = kwargs.pop('span_context')
if not span_context or not target_func:
raise AttributeError("Keyword arguments should contain 'target_func' and 'span_context'")
if not tracing_headers or not target_func:
raise AttributeError("Keyword arguments should contain 'target_func' and 'tracing_headers'")
if _EXPORTER is None:
_EXPORTER = create_exporter(service_name=Config.service_name.value)
_EXPORTER = traces.create_exporter(service_name=Config.service_name.value, config=Config)
span_context = traces.get_trace_propagator().from_headers(tracing_headers)
tracer = open_tracer.Tracer(span_context=span_context,
sampler=AlwaysOnSampler(),
exporter=_EXPORTER)
with tracer.span(name=f"Dask Worker - {target_func.__name__}") as span:
with tracer.span(name=f"Dask Worker - {funcname(target_func)}") as span:
span.span_kind = SpanKind.CLIENT
return target_func(*args, **kwargs)
def _create_func_key(func, *args, **kwargs):
"""
Inspired by Dask code, it returns a hashed key based on function name and given arguments
"""
return funcname(func) + "-" + tokenize(func, kwargs, *args)
def submit_with_trace(dask_client: Client, target_func: Callable, *args, **kwargs):
"""Submit given target_func to Distributed Dask workers and add tracing required stuff
Note: 'dask_task_key' is manually created to easy reading of Dask's running tasks: it will display
the effective targeted function instead of 'wrap_trace_process' used to enable tracing into Dask workers.
"""
tracing_headers = traces.get_trace_propagator().to_headers(get_ctx().tracer.span_context)
kwargs['tracing_headers'] = tracing_headers
kwargs['target_func'] = target_func
dask_task_key = _create_func_key(target_func, *args, **kwargs)
return dask_client.submit(wrap_trace_process, *args, key=dask_task_key, **kwargs)
def map_with_trace(dask_client: Client, target_func: Callable, *args, **kwargs):
"""Submit given target_func to Distributed Dask workers and add tracing required stuff
Note: 'dask_task_key' is manually created to easy reading of Dask's running tasks: it will display
the effective targeted function instead of 'wrap_trace_process' used to enable tracing into Dask workers.
"""
tracing_headers = traces.get_trace_propagator().to_headers(get_ctx().tracer.span_context)
kwargs['tracing_headers'] = tracing_headers
kwargs['target_func'] = target_func
dask_task_key = _create_func_key(target_func, *args, **kwargs)
return dask_client.map(wrap_trace_process, *args, key=dask_task_key, **kwargs)
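# Usage sketch (assumption): submitting a worker-side function while keeping the trace, e.g.:
#
#   future = submit_with_trace(dask_client, bulk_writer.basic_describe, df)
#   describe = await future
#
# The 'tracing_headers' and 'target_func' kwargs injected here are consumed by
# wrap_trace_process on the worker, so the submitted function itself stays unchanged.
# Awaiting the future assumes the Dask client was created with asynchronous=True.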
class TracingMode(Enum):
""" Allow to determine which mode of adding attributes on tracing span is needed. """
CURRENT_SPAN = 1
ROOT_SPAN = 2
def _add_trace_attributes(attributes: dict, tracing_mode: TracingMode):
"""
If a tracer exists, add custom key:value attributes on the root or current span according to the value of 'tracing_mode'.
NOTE: if called by a Dask worker, the parent span is the one created by `wrap_trace_process` function above.
"""
opencensus_tracer = execution_context.get_opencensus_tracer()
if opencensus_tracer is None or not hasattr(opencensus_tracer, 'tracer'):
return
span = None
if tracing_mode == TracingMode.CURRENT_SPAN:
span = opencensus_tracer.current_span()
elif tracing_mode == TracingMode.ROOT_SPAN:
existing_spans = opencensus_tracer.tracer.list_collected_spans()
span = existing_spans[0] if existing_spans else None
if not span:
return
for k, v in attributes.items():
span.add_attribute(attribute_key=k,
attribute_value=v)
def trace_attributes_root_span(attributes):
""" Add attributes to root tracing span """
_add_trace_attributes(attributes, TracingMode.ROOT_SPAN)
def trace_attributes_current_span(attributes):
""" Add attributes to current tracing span """
_add_trace_attributes(attributes, TracingMode.CURRENT_SPAN)
def trace_dataframe_attributes(df: Union[pd.DataFrame, bulk_writer.DataframeBasicDescribe]):
"""
Add dataframe shape into current tracing span if tracer exists
"""
if type(df) is pd.DataFrame:
df = bulk_writer.basic_describe(df)
trace_attributes_current_span({
"df rows count": df.row_count,
"df columns count": df.column_count,
"df index start": df.index_start,
"df index end": df.index_end,
"df index type": df.index_type,
})
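# Usage sketch (assumption): called around bulk read/write paths to annotate the active span, e.g.:
#
#   trace_dataframe_attributes(df)                         # from a pandas DataFrame or a basic describe
#   trace_attributes_root_span({"record_id": record_id})   # custom attributes on the root span
#
# 'record_id' is a hypothetical value used purely for illustration.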
......@@ -12,29 +12,33 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
from itertools import zip_longest
from logging import INFO
from app.helper.logger import get_logger
from app.utils import capture_timings
from typing import List, Optional
import dask.dataframe as dd
import pandas as pd
import pyarrow.parquet as pa
from ...helper.logger import get_logger
from ..capture_timings import capture_timings
WDMS_INDEX_NAME = '_wdms_index_'
def worker_make_log_captured_timing_handler(level=INFO):
"""log captured timing from the worker subprocess (no access to context)"""
def log_captured_timing(tag, wall, cpu):
logger = get_logger()
if logger:
logger.log(level, f"Timing of {tag}, wall={wall:.5f}s, cpu={cpu:.5f}s")
return log_captured_timing
worker_capture_timing_handlers = [worker_make_log_captured_timing_handler(INFO)]
##
def share_items(seq1, seq2):
"""Returns True if seq1 contains common items with seq2."""
......@@ -47,31 +51,6 @@ def by_pairs(iterable):
return zip_longest(*[iter(iterable)] * 2, fillvalue=None)
class SessionFileMeta:
def __init__(self, fs, file_path: str) -> None:
self._fs = fs
file_name = os.path.basename(file_path)
start, end, tail = file_name.split('_')
self.start = float(start) # data time support ?
self.end = float(end)
self.time, self.shape, tail = tail.split('.')
self.columns = self._get_columns(file_path) # TODO lazy load
self.path = file_path
def _get_columns(self, file_path):
path, _ = os.path.splitext(file_path)
with self._fs.open(path + '.meta') as meta_file:
return json.load(meta_file)['columns']
def overlap(self, other: 'SessionFileMeta'):
"""Returns True if indexes overlap."""
return self.end >= other.start and other.end >= self.start
def has_common_columns(self, other):
"""Returns True if contains common columns with others."""
return share_items(self.columns, other.columns)
@capture_timings("set_index", handlers=worker_capture_timing_handlers)
def set_index(ddf: dd.DataFrame):
"""Set index of the dask dataFrame only if needed."""
......@@ -80,18 +59,46 @@ def set_index(ddf: dd.DataFrame):
return ddf
@capture_timings("join_dataframes", handlers=worker_capture_timing_handlers)
def join_dataframes(dfs: List[dd.DataFrame]):
if len(dfs) > 1:
return dfs[0].join(dfs[1:], how='outer')
return dfs[0] if dfs else None
def rename_index(dataframe: pd.DataFrame, name):
"""Rename the dataframe index"""
dataframe.index.name = name
return dataframe
@capture_timings("do_merge", handlers=worker_capture_timing_handlers)
def do_merge(df1: dd.DataFrame, df2: dd.DataFrame):
def do_merge(df1: dd.DataFrame, df2: Optional[dd.DataFrame]):
"""Combine the 2 dask dataframe. Updates df1 with df2 values if overlap."""
if df2 is None:
return df1
df1 = set_index(df1)
df2 = set_index(df2)
df1 = df1.map_partitions(rename_index, WDMS_INDEX_NAME)
df2 = df2.map_partitions(rename_index, WDMS_INDEX_NAME)
if share_items(df1.columns, df2.columns):
ddf = df2.combine_first(df1)
else:
ddf = df1.join(df2, how='outer') # join seems faster when there are no columns in common
return df2.combine_first(df1)
return df1.join(df2, how='outer') # join seems faster when there are no columns in common
return ddf
@capture_timings("get_num_rows", handlers=worker_capture_timing_handlers)
def get_num_rows(dataset: pa.ParquetDataset) -> int:
"""Returns the number of rows from a pyarrow ParquetDataset"""
metadata = dataset.common_metadata
if metadata and metadata.num_rows > 0:
return metadata.num_rows
return sum((piece.get_metadata().num_rows for piece in dataset.pieces))
@capture_timings("index_union", handlers=worker_capture_timing_handlers)
def index_union(idx1: pd.Index, idx2: Optional[pd.Index]):
"""Union of two Index object (check pd.Index.union doc string for more details)"""
return idx1.union(idx2) if idx2 is not None else idx1
......@@ -13,11 +13,12 @@
# limitations under the License.
import io
from typing import Tuple
import pandas as pd
from osdu.core.api.storage.blob_storage_base import BlobStorageBase
from app.utils import Context
from app.context import Context
from .blob_storage import (
BlobBulk,
......@@ -25,18 +26,19 @@ from .blob_storage import (
create_and_write_blob,
read_blob,
)
from .bulk_id import BulkId
from .mime_types import MimeTypes
from .bulk_id import new_bulk_id
from .dask.errors import internal_bulk_exceptions
from .mime_types import MimeTypes, MimeType
from .tenant_provider import resolve_tenant
from ..helper.traces import with_trace
async def create_and_store_dataframe(ctx: Context, df: pd.DataFrame) -> str:
"""Store bulk on a blob storage"""
new_bulk_id = BulkId.new_bulk_id()
bulk_id = new_bulk_id()
tenant = await resolve_tenant(ctx.partition_id)
async with create_and_write_blob(
df, file_exporter=BlobFileExporters.PARQUET, blob_id=new_bulk_id
df, file_exporter=BlobFileExporters.PARQUET, blob_id=bulk_id
) as bulkblob:
storage: BlobStorageBase = await ctx.app_injector.get(BlobStorageBase)
await storage.upload(
......@@ -49,8 +51,7 @@ async def create_and_store_dataframe(ctx: Context, df: pd.DataFrame) -> str:
return bulkblob.id
@with_trace('get_dataframe')
async def get_dataframe(ctx: Context, bulk_id: str) -> pd.DataFrame:
async def download_bulk(ctx: Context, bulk_id: str) -> Tuple[bytes, MimeType]:
""" fetch bulk from a blob storage, provide column major """
tenant = await resolve_tenant(ctx.partition_id)
storage: BlobStorageBase = await ctx.app_injector.get(BlobStorageBase)
......@@ -59,10 +60,17 @@ async def get_dataframe(ctx: Context, bulk_id: str) -> pd.DataFrame:
# for now use a fixed parquet format, saving one call
# meta_data = await storage.download_metadata(tenant.project_id, tenant.bucket_name, bulk_id)
# content_type = meta_data.metadata["content_type"]
return bytes_data, MimeTypes.PARQUET
@internal_bulk_exceptions
@with_trace('get_dataframe')
async def get_dataframe(ctx: Context, bulk_id: str) -> pd.DataFrame:
bytes_data, content_type = await download_bulk(ctx, bulk_id)
blob = BlobBulk(
id=bulk_id,
data=io.BytesIO(bytes_data),
content_type=MimeTypes.PARQUET.type,
content_type=content_type.type,
)
data_frame = await read_blob(blob)
return data_frame
......@@ -13,18 +13,17 @@
# limitations under the License.
import asyncio
from functools import partial
from io import BytesIO
from typing import Union, AnyStr, IO, Optional, List, Dict
from typing import Union, Optional, List, Dict
from pathlib import Path
import numpy as np
import pandas as pd
from pydantic import BaseModel
from pandas import DataFrame as DataframeClass
from .json_orient import JSONOrient
from .mime_types import MimeTypes
from app.utils import get_pool_executor
from .mime_types import MimeTypes, MimeType
from app.pool_executor import get_pool_executor
from ..helper.traces import with_trace
......@@ -34,9 +33,7 @@ class DataframeSerializerSync:
then provide a unified way to handle various topics: float/double precision, compression, etc.
"""
# todo may be unified with the work from storage.blob_storage
SupportedFormat = [MimeTypes.JSON] # , MimeTypes.MSGPACK]
SupportedFormat = [MimeTypes.JSON, MimeTypes.PARQUET]
""" these are supported format through wellbore ddms APIs """
@classmethod
......@@ -59,21 +56,31 @@ class DataframeSerializerSync:
@classmethod
def to_json(cls,
df: DataframeClass,
df: pd.DataFrame,
orient: Union[str, JSONOrient] = JSONOrient.split,
path_or_buf: Optional[Union[str, Path, IO[AnyStr]]] = None) -> Optional[str]:
**kwargs) -> Optional[str]:
"""
:param df: dataframe to dump
:param orient: format for Json, default is split
:param path_or_buf: File path or object. If not specified, the result is returned as a string.
:param kwargs: keyword arguments will be forwarded to pandas.to_json()
:return: None, or a JSON string if no path_or_buf is provided
"""
orient = JSONOrient.get(orient)
return df.fillna("NaN").to_json(path_or_buf=path_or_buf, orient=orient.value)
return df.fillna("NaN").to_json(orient=orient.value, **kwargs)
@classmethod
def read_parquet(cls, data) -> 'DataframeSerializerAsync.DataframeClass':
def to_parquet(cls, df: pd.DataFrame, path_or_buf=None, *, storage_options=None):
"""
:param df: dataframe to dump
:param path_or_buf: str or file-like object, default None, see Pandas.Dataframe.to_parquet
:param storage_options: storage_options, default None
:return: None or buffer
"""
return df.to_parquet(path_or_buf, index=True, engine='pyarrow', storage_options=storage_options)
@classmethod
def read_parquet(cls, data) -> pd.DataFrame:
"""
:param data: bytes, path object or file-like object
:return: dataframe
......@@ -85,38 +92,97 @@ class DataframeSerializerSync:
return pd.read_parquet(data)
@classmethod
def read_json(cls, data, orient: Union[str, JSONOrient], convert_axes: Optional[bool] = None) -> 'DataframeSerializerAsync.DataframeClass':
def read_json(cls, data, orient: Union[str, JSONOrient]) -> pd.DataFrame:
"""
:param data: bytes str content (valid JSON str), path object or file-like object
:param data: bytes or str content (valid JSON), path object or file-like object. Axes are not converted. In the case
of orient='columns', since the index type is lost, it will still try to coerce the index to 'int'
then 'float'. Columns will remain as string type.
For orient 'split' no conversion is done at all.
:param orient:
:return: dataframe
"""
if isinstance(data, bytes):
data = BytesIO(data)
orient = JSONOrient.get(orient)
return pd.read_json(path_or_buf=data, orient=orient.value, convert_axes=convert_axes).replace("NaN", np.NaN)
df = pd.read_json(
path_or_buf=data, orient=orient.value, convert_axes=False
).replace("NaN", np.NaN)
# this is a corner case: orient 'columns' implies that all column names and index values are passed as strings
# in JSON content.
# In that case, their original types are lost. Since parameter 'convert_axes' is set to False, pandas won't
# try to infer the types of the columns and index.
# Regarding columns, it remains OK since WDMS enforces them to be string. Then, using orient 'columns' will cast
# them to string 'by design'.
# For the index values it's problematic. In most cases, those are integer values and it matters to bring them
# back to the original type if possible.
#
# Here's the tradeoff to handle the case orient='columns':
# - no convert on columns, so remains as string type
# - try to coerce index to 'float64' or 'int64'
#
# This is similar to what is done in Pandas but only for index:
# see https://github.com/pandas-dev/pandas/blob/master/pandas/io/json/_json.py#L916
if orient == JSONOrient.columns:
for dtype in ['int64', 'float64']:
try:
# try to coerce index type as int then float
df.index = df.index.astype(dtype)
return df
except (TypeError, ValueError, OverflowError):
continue
return df
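# Illustration of the orient='columns' corner case described above (values are hypothetical):
#
#   df = DataframeSerializerSync.read_json('{"MD": {"0": 1.0, "1": 2.0}}', JSONOrient.columns)
#   assert df.index.dtype == 'int64'   # index coerced back from string to int
#
# With orient='split' the index values are kept as given, without any coercion.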
@classmethod
def load(cls, file_like_data,
content_type: MimeType,
orient: Optional[Union[str, JSONOrient]] = None) -> pd.DataFrame:
"""
deserialize the input data as a pandas dataframe
:param file_like_data: input ipc raw bytes wrapped (file-like obj)
:param content_type: content type value (supports json and parquet)
:param orient: for JSON content, orient must be provided.
:return: pandas dataframe
:throw: ValueError
"""
if content_type == MimeTypes.JSON:
return cls.read_json(file_like_data, orient=orient)
elif content_type == MimeTypes.PARQUET:
return cls.read_parquet(file_like_data)
else:
raise ValueError(f"unsupported content_type {content_type}")
class DataframeSerializerAsync:
def __init__(self, pool_executor=get_pool_executor()):
self.executor = pool_executor
@with_trace("Parquet bulk serialization")
async def to_parquet(self, df: pd.DataFrame, *, storage_options=None) -> pd.DataFrame:
func = partial(DataframeSerializerSync.to_parquet, df, storage_options=storage_options)
return await asyncio.get_event_loop().run_in_executor(self.executor, func)
@with_trace("JSON bulk serialization")
async def to_json(self,
df: DataframeClass,
df: pd.DataFrame,
orient: Union[str, JSONOrient] = JSONOrient.split,
path_or_buf: Optional[Union[str, Path, IO[AnyStr]]] = None) -> Optional[str]:
return await asyncio.get_event_loop().run_in_executor(
self.executor, DataframeSerializerSync.to_json, df, orient, path_or_buf
)
*args, **kwargs) -> Optional[str]:
func = partial(DataframeSerializerSync.to_json, df, orient, *args, **kwargs)
return await asyncio.get_event_loop().run_in_executor(self.executor, func)
@with_trace("Parquet bulk deserialization")
async def read_parquet(self, data) -> DataframeClass:
async def read_parquet(self, data) -> pd.DataFrame:
return await asyncio.get_event_loop().run_in_executor(
self.executor, DataframeSerializerSync.read_parquet, data
)
@with_trace("Parquet JSON deserialization")
async def read_json(self, data, orient: Union[str, JSONOrient], convert_axes: Optional[bool] = None) -> DataframeClass:
async def read_json(self, data, orient: Union[str, JSONOrient]) -> pd.DataFrame:
return await asyncio.get_event_loop().run_in_executor(
self.executor, DataframeSerializerSync.read_json, data, orient, convert_axes
self.executor, DataframeSerializerSync.read_json, data, orient
)
from typing import Tuple, Callable, Iterable, List
import re
import pandas as pd
from .dask.utils import WDMS_INDEX_NAME
from .dask.errors import BulkNotProcessable
from app.conf import Config
ValidationResult = Tuple[bool, str] # Tuple (is_dataframe_valid, failure_reason)
ValidationSuccess = (True, '')
DataFrameValidationFunc = Callable[[pd.DataFrame], ValidationResult]
def assert_df_validate(dataframe: pd.DataFrame,
validation_funcs: List[DataFrameValidationFunc]):
""" call one or more validation function and throw BulkNotProcessable in case of invalid, run all validation before
returning """
if not validation_funcs:
return
all_validity, all_reasons = zip(*[fn(dataframe) for fn in validation_funcs])
if not all(all_validity):
# raise exception with all invalid reasons
raise BulkNotProcessable(message=",".join([msg for ok, msg in zip(all_validity, all_reasons) if not ok]))
# the following functions are stateless and without side effects, so they can easily be used in a parallel/cross-process context
def no_validation(_) -> ValidationResult:
"""
Always validate the given dataframe without error/warning
return True, ''
"""
return ValidationSuccess
def auto_cast_columns_to_string(df: pd.DataFrame) -> ValidationResult:
"""
If the given dataframe contains column names which are not strings, cast them
Always returns validation success
"""
df.columns = df.columns.astype(str)
return ValidationSuccess
def columns_type_must_be_string(df: pd.DataFrame) -> ValidationResult:
""" Ensure given dataframe contains columns name as string only as described by WellLog schemas """
if all((type(t) is str for t in df.columns)):
return ValidationSuccess
return False, 'All columns type should be string'
def validate_index(df: pd.DataFrame) -> ValidationResult:
""" Ensure index """
if len(df.index) == 0:
return False, "Empty data"
if not df.index.is_numeric() and not isinstance(df.index, pd.DatetimeIndex):
return False, "Index should be numeric or datetime"
if not df.index.is_unique:
return False, "Duplicated index found"
return ValidationSuccess
def validate_number_of_columns(df: pd.DataFrame) -> ValidationResult:
""" Verify max number of columns """
if len(df.columns) > Config.max_columns_per_chunk_write.value:
return False, f"Too many columns : maximum allowed '{Config.max_columns_per_chunk_write.value}'"
return ValidationSuccess
PandasReservedIndexColRegexp = re.compile(r'__index_level_\d+__')
def is_reserved_column_name(name: str) -> bool:
"""Return True if the name is a reserved column name by Pandas/Dask with PyArrow"""
return (PandasReservedIndexColRegexp.match(name)
or name == '__null_dask_index__'
or name == WDMS_INDEX_NAME)
def any_reserved_column_name(names: Iterable[str]) -> bool:
"""
There are reserved names for columns which are internally used by Pandas/Dask with PyArrow to save the index.
Saving a df containing a reserved name as a regular column then makes the parquet file impossible to read.
At this stage, columns used as index are already marked as index and are not considered as columns by Pandas.
return: True if any column uses a reserved name
"""
return any(is_reserved_column_name(name) for name in names if type(name) is str)
def columns_not_in_reserved_names(df: pd.DataFrame) -> ValidationResult:
if any_reserved_column_name(df.columns):
return False, 'Invalid column name'
return ValidationSuccess
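# Usage sketch: validators are composed via assert_df_validate, which collects every failure
# reason before raising BulkNotProcessable, e.g.:
#
#   df = pd.DataFrame({'MD': [1.0, 2.0]}, index=[0, 1])
#   assert_df_validate(df, [columns_type_must_be_string,
#                           validate_number_of_columns,
#                           columns_not_in_reserved_names,
#                           validate_index])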
......@@ -23,6 +23,8 @@ class MimeType(NamedTuple):
alternative_types: List[str] = []
def match(self, str_value: str) -> bool:
if not str_value:
return False
normalized_value = str_value.lower()
return any(
(
......@@ -55,8 +57,6 @@ class MimeTypes:
JSON = MimeType(type="application/json", extension=".json")
CSV = MimeType(type="text/csv", extension=".csv")
MSGPACK = MimeType(
type="application/x-msgpack",
extension=".msgpack",
......
import tempfile
from os import path, makedirs
def _setup_temp_dir() -> str:
tmpdir = tempfile.gettempdir()
if not tmpdir.endswith('wdmsosdu'):
tmpdir = path.join(tmpdir, 'wdmsosdu')
makedirs(tmpdir, exist_ok=True)
tempfile.tempdir = tmpdir
return tmpdir
_TEMP_DIR = _setup_temp_dir()
def get_temp_dir():
return _TEMP_DIR
......@@ -29,7 +29,7 @@ async def resolve_tenant(data_partition_id: str) -> Tenant:
return Tenant(
data_partition_id=data_partition_id,
project_id='',
bucket_name='wdms-osdu'
bucket_name=Config.az_bulk_container
)
if Config.cloud_provider.value == 'ibm':
......
......@@ -17,7 +17,6 @@ import odes_search
import odes_storage
from odes_search.api_client import AsyncSearchApi
from odes_storage.api_client import AsyncRecordsApi
from app.conf import Config
from dataclasses import dataclass
from typing import Optional
......@@ -40,26 +39,26 @@ class Limits:
keepalive_expiry: Optional[float] = 5.0
def make_search_client(host) -> SearchServiceClient:
def make_search_client(*, host, timeout, max_connections=None, max_keepalive_connections=None) -> SearchServiceClient:
search_client = odes_search.ApiClient(
host=host,
timeout=Config.de_client_config_timeout.value,
timeout=timeout,
limits=Limits(
max_connections=Config.de_client_config_max_connection.value or None,
max_keepalive_connections=Config.de_client_config_max_keepalive.value or None)
max_connections=max_connections,
max_keepalive_connections=max_keepalive_connections)
)
search_client.add_middleware(middleware=client_middleware)
search_client.add_middleware(middleware=backoff_middleware)
return odes_search.AsyncApis(search_client).search_api
def make_storage_record_client(host) -> StorageRecordServiceClient:
def make_storage_record_client(*, host, timeout, max_connections=None, max_keepalive_connections=None) -> StorageRecordServiceClient:
storage_client = odes_storage.ApiClient(
host=host,
timeout=Config.de_client_config_timeout.value,
timeout=timeout,
limits=Limits(
max_connections=Config.de_client_config_max_connection.value or None,
max_keepalive_connections=Config.de_client_config_max_keepalive.value or None)
max_connections=max_connections,
max_keepalive_connections=max_keepalive_connections)
)
storage_client.add_middleware(middleware=client_middleware)
storage_client.add_middleware(middleware=backoff_middleware)
......
......@@ -4,15 +4,25 @@ from app.conf import Config
from httpx import (
RemoteProtocolError,
TimeoutException) # => ReadTimeout, WriteTimeout, ConnectTimeout, PoolTimeout
from odes_storage.exceptions import ResponseHandlingException
_exceptions_type_to_retry = (RemoteProtocolError, TimeoutException, ResponseHandlingException)
def backoff_policy(on_backoff_handlers=None):
return backoff.on_exception(backoff.expo,
(RemoteProtocolError, TimeoutException, ResponseHandlingException),
"""
Returns: a retry decorator.
Triggered when the raised exception is one of: RemoteProtocolError, TimeoutException, ResponseHandlingException.
It will retry at most `Config.de_client_backoff_max_tries.value` times.
'base', 'factor' and 'max_value' are kwargs for the `backoff.expo` generator function.
"""
return backoff.on_exception(backoff.expo, # it will generate [1, 2, 4, 5, .. , 5]
_exceptions_type_to_retry,
max_tries=Config.de_client_backoff_max_tries.value,
on_backoff=on_backoff_handlers,
base=0.5,
base=2,
factor=1,
max_value=Config.de_client_backoff_max_wait.value)
......@@ -16,8 +16,9 @@ from opencensus.trace.span import SpanKind
from starlette.status import HTTP_500_INTERNAL_SERVER_ERROR
from app import conf
from app.utils import Context
from app.context import Context
from app.helper import utils, traces
from app.helper.logger import get_logger
from .backoff_policy import backoff_policy
from sys import exc_info
from traceback import format_exception
......@@ -45,11 +46,9 @@ def _before_tracing_attributes(ctx, request):
def backoff_handler_log_it(details):
ctx = Context.current()
exception_type, raised_exec, tb = exc_info()
s_stack = format_exception(exception_type, raised_exec, tb)
ctx.logger.exception(f"Backoff callback, tries={details['tries']}: {raised_exec}. Stack = {s_stack}")
get_logger().exception(f"Backoff callback, tries={details['tries']}: {raised_exec}. Stack = {s_stack}")
@backoff_policy(backoff_handler_log_it)
......@@ -68,13 +67,15 @@ async def client_middleware(request, call_next):
tracing_headers = traces.get_trace_propagator().to_headers(span.context_tracer.span_context)
request.headers.update(tracing_headers)
ctx.logger.debug(f"client_middleware - url: {request.url} - tracing_headers: {tracing_headers}")
get_logger().debug(f"client_middleware - url: {request.url} - tracing_headers: {tracing_headers}")
request.headers[conf.AUTHORIZATION_HEADER_NAME] = f'Bearer {ctx.auth}'
if ctx.correlation_id:
request.headers[conf.CORRELATION_ID_HEADER_NAME] = ctx.correlation_id
if ctx.app_key:
request.headers[conf.APP_KEY_HEADER_NAME] = ctx.app_key
if ctx.x_user_id:
request.headers[conf.X_USER_ID_HEADER_NAME] = ctx.x_user_id
response = None
try:
......
......@@ -13,7 +13,7 @@
# limitations under the License.
from app.clients import SearchServiceClient
from app.utils import Context
from app.context import Context
async def get_search_service(ctx: Context) -> SearchServiceClient:
......
......@@ -12,12 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import hashlib
import uuid
from asyncio import gather, iscoroutinefunction
from typing import List
import httpx
from app.model import model_utils
from fastapi import HTTPException, status
from odes_storage import UnexpectedResponse
from odes_storage.models import (CreateUpdateRecordsResponse, Record,
RecordVersions)
from osdu.core.api.storage.blob_storage_base import BlobStorageBase
......@@ -64,8 +68,8 @@ class StorageRecordServiceBlobStorage:
@staticmethod
def _get_record_folder(id: str, data_partition: str):
encoded_id = hash(id)
def _get_record_folder(record_id: str, data_partition: str):
encoded_id = hashlib.md5(record_id.encode()).hexdigest()
folder = f'{data_partition or "global"}_r_{encoded_id}'
return folder
......@@ -114,7 +118,7 @@ class StorageRecordServiceBlobStorage:
# manual for now
return CreateUpdateRecordsResponse(recordCount=len(record_list),
recordIds=[record.id for record in record_list],
recordIdVersions=[record.version for record in record_list],
recordIdVersions=[f"{record.id}:{record.version}" for record in record_list],
skipped_record_ids=[])
async def get_record_version(self,
......@@ -133,7 +137,13 @@ class StorageRecordServiceBlobStorage:
object_name)
return Record.parse_raw(bin_data)
except (FileNotFoundError, ResourceNotFoundException):
raise HTTPException(status_code=404, detail="Item not found")
raise UnexpectedResponse(
status_code=404,
reason_phrase="Item not found",
# not sure what to put here at this time
content="".encode(encoding="utf-8"),
headers=httpx.Headers(),
)
async def get_all_record_versions(self,
id: str,
......@@ -166,7 +176,24 @@ class StorageRecordServiceBlobStorage:
Tenant(project_id=self._project, bucket_name=self._container, data_partition_id=data_partition_id),
object_name)
except FileNotFoundError:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Item not found")
raise UnexpectedResponse(
status_code=404,
reason_phrase="Item not found",
# not sure what to put here at this time
content="".encode(encoding="utf-8"),
headers=httpx.Headers(),
)
async def delete_records(self, data_partition_id: str, request_body: List[str]) -> None:
await gather(*[
self._storage.delete(
Tenant(project_id=self._project, bucket_name=self._container, data_partition_id=data_partition_id),
record_id
)
for record_id in request_body
], return_exceptions=False) # return_exceptions=False means it will throw if a single error occurs
return httpx.Response(status_code=204)
async def get_schema(self, kind, data_partition_id=None, appkey=None, token=None, *args, **kwargs):
raise NotImplementedError('StorageServiceBlobStorage.get_schema')
......@@ -13,7 +13,7 @@
# limitations under the License.
from app.clients import StorageRecordServiceClient
from app.utils import Context
from app.context import Context
async def get_storage_record_service(ctx: Context) -> StorageRecordServiceClient:
......