
Compare revisions

Changes are shown as if the source revision was being merged into the target revision.

Commits on Source (34)
Showing 255 additions and 89 deletions
......@@ -240,10 +240,14 @@ osdu-gcp-test-python:
osdu-gcp-dev2-test-python:
extends: .osdu-gcp-dev2-variables
stage: integration
image: gcr.io/google.com/cloudsdktool/cloud-sdk
image: python:3.8
needs: ["osdu-gcp-dev2-deploy-deployment"]
script:
- apt-get install -y python3-venv
- bash
- curl https://sdk.cloud.google.com > install.sh
- bash install.sh --disable-prompts
- source /root/google-cloud-sdk/completion.bash.inc
- source /root/google-cloud-sdk/path.bash.inc
- python3 -m venv env
- source env/bin/activate
- pip install --upgrade pip
......
......@@ -20,6 +20,8 @@ from .mime_types import MimeTypes
from .tenant_provider import resolve_tenant
from .exceptions import UnknownChannelsException, InvalidBulkException, NoBulkException, NoDataException, RecordNotFoundException
from .consistency_checks import ConsistencyException, DataConsistencyChecks
from .dask.client import DaskClient
from .dask.localcluster import DaskException
# TMP: this should probably not be exposed outside of the bulk_persistence package
from .temp_dir import get_temp_dir
......@@ -15,13 +15,13 @@
This module groups functions related to the bulk catalog.
A catalog contains the metadata of the chunks
"""
import asyncio
import functools
import json
from contextlib import suppress
from dataclasses import dataclass
from typing import Dict, Iterable, List, NamedTuple, Optional, Set
from dask.distributed import get_client
from .traces import submit_with_trace, trace_attributes_root_span
from app.helper.traces import with_trace
from app.utils import capture_timings
......@@ -179,32 +179,26 @@ CATALOG_FILE_NAME = 'bulk_catalog.json'
@capture_timings('save_bulk_catalog', handlers=worker_capture_timing_handlers)
@with_trace('save_bulk_catalog')
def save_bulk_catalog(filesystem, folder_path: str, catalog: BulkCatalog) -> None:
async def save_bulk_catalog(filesystem, folder_path: str, catalog: BulkCatalog) -> None:
"""save a bulk catalog to a json file in the given folder path"""
folder_path, _ = remove_protocol(folder_path)
meta_path = join(folder_path, CATALOG_FILE_NAME)
with filesystem.open(meta_path, 'w') as outfile:
data = json.dumps(catalog.as_dict(), indent=0)
_func = functools.partial(json.dumps, catalog.as_dict(), indent=0)
data = await asyncio.get_running_loop().run_in_executor(None, _func)
outfile.write(data)
# json.dump(catalog.as_dict(), outfile) # don't know why json.dump is slower (local windows)
@capture_timings('load_bulk_catalog', handlers=worker_capture_timing_handlers)
@with_trace('load_bulk_catalog')
def load_bulk_catalog(filesystem, folder_path: str) -> Optional[BulkCatalog]:
async def load_bulk_catalog(filesystem, folder_path: str) -> Optional[BulkCatalog]:
"""load a bulk catalog from a json file in the given folder path"""
folder_path, _ = remove_protocol(folder_path)
meta_path = join(folder_path, CATALOG_FILE_NAME)
with suppress(FileNotFoundError):
with filesystem.open(meta_path) as json_file:
data = json.load(json_file)
data = await asyncio.get_running_loop().run_in_executor(None, json.load, json_file)
return BulkCatalog.from_dict(data)
return None
async def async_load_bulk_catalog(filesystem, folder_path: str) -> BulkCatalog:
return await submit_with_trace(get_client(), load_bulk_catalog, filesystem, folder_path)
async def async_save_bulk_catalog(filesystem, folder_path: str, catalog: BulkCatalog) -> None:
return await submit_with_trace(get_client(), save_bulk_catalog, filesystem, folder_path, catalog)
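The rewrite above drops the submit_with_trace wrappers and instead awaits the blocking JSON (de)serialization on the default executor, so the event loop stays responsive while large catalogs are encoded. A minimal sketch of that run_in_executor pattern, with a hypothetical payload and target path standing in for the catalog and filesystem:

```python
import asyncio
import functools
import json


async def dump_catalog_json(payload: dict, path: str) -> None:
    """Serialize a potentially large dict to JSON without blocking the event loop.

    Sketch only: `payload` and `path` are hypothetical stand-ins for the
    catalog dict and filesystem path used in save_bulk_catalog above.
    """
    loop = asyncio.get_running_loop()
    # json.dumps is CPU-bound for large catalogs; run it on the default
    # ThreadPoolExecutor (executor=None) and await the result.
    dump = functools.partial(json.dumps, payload, indent=0)
    data = await loop.run_in_executor(None, dump)
    with open(path, "w") as outfile:  # the write itself stays synchronous, as in the diff
        outfile.write(data)


if __name__ == "__main__":
    asyncio.run(dump_catalog_json({"columns": [], "nb_rows": 0}, "/tmp/bulk_catalog.json"))
```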
# Copyright 2021 Schlumberger
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import dask
from dask.utils import format_bytes
from dask.distributed import Client as DaskDistributedClient
from distributed import LocalCluster
from app.conf import Config
from .localcluster import get_dask_configuration
HOUR = 3600 # in seconds
class DaskClient:
# singleton of DaskDistributedClient class
client: DaskDistributedClient = None
# Ensure access to critical section is done for only one coroutine
lock_client: asyncio.Lock = None
@staticmethod
async def create() -> DaskDistributedClient:
if not DaskClient.lock_client:
DaskClient.lock_client = asyncio.Lock()
if not DaskClient.client:
async with DaskClient.lock_client:
if not DaskClient.client:
from app.helper.logger import get_logger
logger = get_logger()
logger.info(f"Dask client initialization started...")
n_workers, threads_per_worker, worker_memory_limit = get_dask_configuration(config=Config, logger=logger)
logger.info(f"Dask client worker configuration: {n_workers} workers running with "
f"{format_bytes(worker_memory_limit)} of RAM and {threads_per_worker} threads each")
# Ensure memory used by workers is freed regularly despite potential memory leaks
dask.config.set({'distributed.worker.lifetime.duration': HOUR * 24})
dask.config.set({'distributed.worker.lifetime.stagger': HOUR * 1})
dask.config.set({'distributed.worker.lifetime.restart': True})
logger.info(f"Dask cluster configuration - "
f"worker lifetime: {dask.config.get('distributed.worker.lifetime.duration')}s. "
f"stagger: {dask.config.get('distributed.worker.lifetime.stagger')}s.")
cluster = await LocalCluster(
asynchronous=True,
processes=True,
threads_per_worker=threads_per_worker,
n_workers=n_workers,
memory_limit=worker_memory_limit,
dashboard_address=None
)
# A worker can be killed while executing a task once its lifetime duration has elapsed;
# "cluster.adapt(min=N, max=N)" ensures workers are respawned if that happens
cluster.adapt(minimum=n_workers, maximum=n_workers)
DaskClient.client = await DaskDistributedClient(cluster, asynchronous=True)
get_logger().info(f"Dask client initialized : {DaskClient.client}")
return DaskClient.client
@staticmethod
async def close():
if not DaskClient.lock_client:
return
async with DaskClient.lock_client:
if DaskClient.client:
# closing the cluster (started independently from the client)
cluster = await DaskClient.client.cluster
await cluster.close()
await DaskClient.client.close() # or shutdown
DaskClient.client = None
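The new DaskClient is a lock-guarded async singleton: the first create() builds the LocalCluster and the distributed client, later calls return the cached instance, and close() tears both down and resets it. A minimal usage sketch, assuming the import path implied by the hunks above; the `double` task is purely illustrative:

```python
import asyncio

from app.bulk_persistence.dask.client import DaskClient  # path assumed from the hunks above


def double(x: int) -> int:
    """Hypothetical task submitted to the cluster for illustration."""
    return 2 * x


async def main() -> None:
    client = await DaskClient.create()          # first call spins up the cluster
    assert client is await DaskClient.create()  # subsequent calls reuse the singleton

    result = await client.submit(double, 21)    # async client: the returned future is awaitable
    print(result)                               # 42

    await DaskClient.close()                    # closes cluster and client, resets the singleton


if __name__ == "__main__":
    asyncio.run(main())
```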
......@@ -27,13 +27,13 @@ from osdu.core.api.storage.dask_storage_parameters import DaskStorageParameters
from app.helper.logger import get_logger
from app.helper.traces import with_trace
from app.persistence.sessions_storage import Session
from app.utils import DaskClient, capture_timings
from app.utils import capture_timings
from app.conf import Config
from .client import DaskClient
from .dask_worker_plugin import DaskWorkerPlugin
from .errors import BulkRecordNotFound, BulkNotProcessable, internal_bulk_exceptions
from .traces import map_with_trace, submit_with_trace, trace_attributes_root_span
from .traces import map_with_trace, submit_with_trace, trace_attributes_root_span, trace_attributes_current_span
from .utils import (WDMS_INDEX_NAME, by_pairs, do_merge, join_dataframes, worker_capture_timing_handlers,
get_num_rows, set_index, index_union)
from ..dataframe_validators import is_reserved_column_name, DataFrameValidationFunc
......@@ -42,13 +42,14 @@ from . import storage_path_builder as pathBuilder
from . import session_file_meta as session_meta
from ..bulk_id import new_bulk_id
from .bulk_catalog import (BulkCatalog, ChunkGroup,
async_load_bulk_catalog,
async_save_bulk_catalog)
load_bulk_catalog,
save_bulk_catalog)
from ..mime_types import MimeType
from .dask_data_ipc import DaskNativeDataIPC, DaskLocalFileDataIPC
from . import dask_worker_write_bulk as bulk_writer
from ..consistency_checks import DataConsistencyChecks
def read_with_dask(path: Union[str, List[str]], **kwargs) -> dd.DataFrame:
"""call dask.dataframe.read_parquet with default parameters
Dask read_parquet parameters:
......@@ -180,6 +181,11 @@ class DaskBulkStorage:
if index_df:
dfs.append(index_df)
trace_attributes_current_span({
'parquet-files-to-load-count': len(files_to_load),
'df-to-merge-count': len(dfs)
})
if not dfs:
raise RuntimeError("cannot find requested columns")
......@@ -250,7 +256,7 @@ class DaskBulkStorage:
@with_trace('get_bulk_catalog')
async def get_bulk_catalog(self, record_id: str, bulk_id: str, generate_if_not_exists=True) -> BulkCatalog:
bulk_path = pathBuilder.record_bulk_path(self.base_directory, record_id, bulk_id)
catalog = await async_load_bulk_catalog(self._fs, bulk_path)
catalog = await load_bulk_catalog(self._fs, bulk_path)
if catalog:
return catalog
......@@ -467,8 +473,7 @@ class DaskBulkStorage:
self._fill_catalog_columns_info(catalog, chunk_metas, bulk_id)
)
fcatalog = await self.client.scatter(catalog)
await async_save_bulk_catalog(self._fs, commit_path, fcatalog)
await save_bulk_catalog(self._fs, commit_path, catalog)
trace_attributes_root_span({
'catalog-row-count': catalog.nb_rows,
'catalog-col-count': catalog.all_columns_count
......
......@@ -27,7 +27,7 @@ class DaskWorkerPlugin(WorkerPlugin):
logger.debug("WorkerPlugin initialised")
def setup(self, worker):
init_logger(service_name=Config.service_name.value)
init_logger(service_name=Config.service_name.value, config=Config)
self.worker = worker
if self._register_fsspec_implementation:
......
from logging import Logger
from dask.utils import format_bytes, parse_bytes
from distributed import system
from distributed.deploy.utils import nprocesses_nthreads
from app.conf import Config
class DaskException(Exception):
pass
# Amount of memory reserved for the FastAPI server + ProcessPoolExecutors
memory_leeway = parse_bytes("600Mi")
def min_worker_memory_recommended(config: Config):
"""Minimal amount of memory required for a Dask worker to not get bad performances"""
return parse_bytes(config.min_worker_memory.value)
def system_memory():
"""returns the detected memory limit for this system (done by distributed)"""
return system.MEMORY_LIMIT
def available_memory_for_workers():
"""Return amount of RAM available for Dask's workers after withdrawing RAM required by server itself"""
return max(0, (system_memory() - memory_leeway))
def recommended_workers_and_threads():
""" Return the recommended numbers of worker and threads according the cpus available provided by Dask """
return nprocesses_nthreads()
def get_dask_configuration(*, config: Config, logger: Logger):
"""
Return recommended Dask workers configuration
"""
n_workers, threads_per_worker = recommended_workers_and_threads()
available_memory_bytes = available_memory_for_workers()
worker_memory_limit = int(available_memory_bytes / n_workers)
logger.info(
f"Dask client - system.MEMORY_LIMIT: {format_bytes(system_memory())} "
f"- available_memory_bytes: {format_bytes(available_memory_bytes)} "
f"- min_worker_memory_recommended: {format_bytes(min_worker_memory_recommended(config))} "
f"- computed worker_memory_limit: {format_bytes(worker_memory_limit)} for {n_workers} workers"
)
if min_worker_memory_recommended(config) > worker_memory_limit:
n_workers = available_memory_bytes // min_worker_memory_recommended(config)
if not n_workers >= 1:
min_memory = min_worker_memory_recommended(config) + memory_leeway
message = (
f"Not enough memory available to start Dask worker. "
f"Please, consider upgrading container memory to {format_bytes(min_memory)}"
)
logger.error(
f"Dask client - {message} - "
f"n_workers: {n_workers} threads_per_worker: {threads_per_worker}, "
f"available_memory_bytes: {available_memory_bytes} "
)
raise DaskException(message)
worker_memory_limit = available_memory_bytes / n_workers
logger.warning(
f"Dask client - available RAM is too low. Reducing number of workers "
f"to {n_workers} running with {format_bytes(worker_memory_limit)} of RAM"
)
return n_workers, threads_per_worker, worker_memory_limit
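The sizing logic above splits the RAM left after the 600 Mi leeway evenly across the workers recommended for the CPU count, then falls back to fewer, larger workers if each share would drop below the configured minimum. A quick numeric sanity check of that arithmetic under assumed values (none of these figures are the service defaults):

```python
from dask.utils import format_bytes, parse_bytes

# Assumed inputs, for illustration only.
system_memory = parse_bytes("4Gi")        # pretend distributed detected 4 GiB
memory_leeway = parse_bytes("600Mi")      # reserved for the FastAPI server, as above
n_workers, threads_per_worker = 4, 2      # pretend nprocesses_nthreads() returned this
min_worker_memory = parse_bytes("512Mi")  # pretend Config.min_worker_memory

available = max(0, system_memory - memory_leeway)  # ~3.41 GiB left for workers
worker_memory_limit = int(available / n_workers)   # ~874 MiB each, above the 512 MiB minimum

if min_worker_memory > worker_memory_limit:
    # Only triggers on tighter hosts: keep fewer workers so each gets enough RAM.
    n_workers = available // min_worker_memory
    worker_memory_limit = available / n_workers

print(n_workers, threads_per_worker, format_bytes(worker_memory_limit))
# -> 4 workers, 2 threads each, ~874 MiB per worker
```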
......@@ -28,7 +28,7 @@ def wrap_trace_process(*args, **kwargs):
raise AttributeError("Keyword arguments should contain 'target_func' and 'tracing_headers'")
if _EXPORTER is None:
_EXPORTER = traces.create_exporter(service_name=Config.service_name.value)
_EXPORTER = traces.create_exporter(service_name=Config.service_name.value, config=Config)
span_context = traces.get_trace_propagator().from_headers(tracing_headers)
tracer = open_tracer.Tracer(span_context=span_context,
......@@ -95,7 +95,7 @@ def _add_trace_attributes(attributes: dict, tracing_mode: TracingMode):
span = None
if tracing_mode == TracingMode.CURRENT_SPAN:
span = opencensus_tracer.tracer.current_span()
span = opencensus_tracer.current_span()
elif tracing_mode == TracingMode.ROOT_SPAN:
existing_spans = opencensus_tracer.tracer.list_collected_spans()
span = existing_spans[0] if existing_spans else None
......
......@@ -29,7 +29,7 @@ async def resolve_tenant(data_partition_id: str) -> Tenant:
return Tenant(
data_partition_id=data_partition_id,
project_id='',
bucket_name='wdms-osdu'
bucket_name=Config.az_bulk_container
)
if Config.cloud_provider.value == 'ibm':
......
......@@ -18,6 +18,7 @@ from starlette.status import HTTP_500_INTERNAL_SERVER_ERROR
from app import conf
from app.context import Context
from app.helper import utils, traces
from app.helper.logger import get_logger
from .backoff_policy import backoff_policy
from sys import exc_info
from traceback import format_exception
......@@ -45,11 +46,9 @@ def _before_tracing_attributes(ctx, request):
def backoff_handler_log_it(details):
ctx = Context.current()
exception_type, raised_exec, tb = exc_info()
s_stack = format_exception(exception_type, raised_exec, tb)
ctx.logger.exception(f"Backoff callback, tries={details['tries']}: {raised_exec}. Stack = {s_stack}")
get_logger().exception(f"Backoff callback, tries={details['tries']}: {raised_exec}. Stack = {s_stack}")
@backoff_policy(backoff_handler_log_it)
......@@ -68,7 +67,7 @@ async def client_middleware(request, call_next):
tracing_headers = traces.get_trace_propagator().to_headers(span.context_tracer.span_context)
request.headers.update(tracing_headers)
ctx.logger.debug(f"client_middleware - url: {request.url} - tracing_headers: {tracing_headers}")
get_logger().debug(f"client_middleware - url: {request.url} - tracing_headers: {tracing_headers}")
request.headers[conf.AUTHORIZATION_HEADER_NAME] = f'Bearer {ctx.auth}'
if ctx.correlation_id:
......
......@@ -316,6 +316,8 @@ def cloud_provider_additional_environment(config: ConfigurationContainer):
is_mandatory=False,
override=True)
config.az_bulk_container = 'wdms-osdu'
if provider == 'gcp':
config.add_from_env(attribute_name='default_data_tenant_project_id',
env_var_key='OS_WELLBORE_DDMS_DATA_PROJECT_ID',
......
......@@ -16,17 +16,16 @@ import contextvars
from typing import Optional
import json
from app.conf import Config
from app.model.user import User
from app.injector.app_injector import AppInjector
class Context:
"""
Immutable object providing contextual information throughout request processing
"""
__slots__ = [
'_tracer',
'_logger',
'_correlation_id',
'_request_id',
'_dev_mode',
......@@ -42,10 +41,9 @@ class Context:
def __init__(self,
tracer=None,
logger=None,
correlation_id: Optional[str] = None,
request_id: Optional[str] = None,
dev_mode: bool = Config.dev_mode.value,
dev_mode: Optional[bool] = None,
auth=None,
partition_id: Optional[str] = None,
app_key: Optional[str] = None,
......@@ -56,7 +54,6 @@ class Context:
**keys):
self._tracer = tracer
self._logger = logger
self._correlation_id = correlation_id
self._request_id = request_id
self._dev_mode = dev_mode
......@@ -89,9 +86,9 @@ class Context:
Context.__ctx_var.set(self)
@classmethod
def set_current_with_value(cls, tracer=None, logger=None, correlation_id=None, request_id=None, auth=None,
def set_current_with_value(cls, tracer=None, correlation_id=None, request_id=None, auth=None,
partition_id=None, app_key=None, api_key=None, user=None, app_injector=None,
dev_mode=Config.dev_mode.value, x_user_id=None,
dev_mode=None, x_user_id=None,
**keys) -> 'Context':
"""
clone the current context with the given values, set the new ctx as current and returns it
......@@ -100,7 +97,6 @@ class Context:
current = cls.current()
assert current is not None, 'no existing current context'
new_ctx = current.with_value(tracer=tracer,
logger=logger,
correlation_id=correlation_id,
request_id=request_id,
auth=auth,
......@@ -133,7 +129,6 @@ class Context:
def __copy__(self):
return self.__class__(
tracer=self._tracer,
logger=self._logger,
correlation_id=self._correlation_id,
request_id=self._request_id,
dev_mode=self._dev_mode,
......@@ -191,13 +186,12 @@ class Context:
clone._app_injector = app_injector
return clone
def with_value(self, tracer=None, logger=None, correlation_id=None, request_id=None, auth=None,
def with_value(self, tracer=None, correlation_id=None, request_id=None, auth=None,
partition_id=None, app_key=None, api_key=None, user=None, app_injector=None,
dev_mode=Config.dev_mode.value, x_user_id=None, **keys) -> 'Context':
dev_mode=None, x_user_id=None, **keys) -> 'Context':
""" Clone context, adding all keys in future logs """
cloned = self.__class__(
tracer=tracer or self._tracer,
logger=logger or self._logger,
correlation_id=correlation_id or self._correlation_id,
request_id=request_id or self._request_id,
dev_mode=dev_mode or self._dev_mode,
......@@ -218,10 +212,6 @@ class Context:
def tracer(self):
return self._tracer
@property
def logger(self):
return self._logger
@property
def correlation_id(self) -> Optional[str]:
return self._correlation_id
......@@ -231,7 +221,7 @@ class Context:
return self._request_id
@property
def dev_mode(self) -> bool:
def dev_mode(self) -> Optional[bool]:
return self._dev_mode
@property
......@@ -265,7 +255,6 @@ class Context:
def __dict__(self):
return {
"tracer": self.tracer,
"logger": self.logger,
"correlation_id": self.correlation_id,
"request_id": self.request_id,
"dev_mode": self.dev_mode,
......
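Context above no longer carries a logger (and dev_mode now defaults to None instead of being bound to Config at import time); the hunks that follow apply the same mechanical substitution wherever the context logger was used. A before/after sketch with a hypothetical handler, presumably so that code running outside a request context (for example functions shipped to Dask workers, whose plugin now initializes the logger in its setup hunk above) can log the same way:

```python
from app.helper.logger import get_logger


async def delete_record(record_id: str) -> None:
    """Hypothetical handler illustrating the logging change applied below."""
    # Before this MR the logger travelled on the request context:
    #     Context.current().logger.debug(f"deleting record {record_id}")
    # After, the shared logger is fetched directly, with no request context needed:
    get_logger().debug(f"deleting record {record_id}")
```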
......@@ -33,6 +33,7 @@ from starlette.requests import Request
from starlette.responses import JSONResponse
from starlette.status import HTTP_500_INTERNAL_SERVER_ERROR
from app.context import get_ctx
from app.helper.logger import get_logger
OSDU_DATA_ECOSYSTEM_SEARCH = "osdu-data-ecosystem-search"
OSDU_DATA_ECOSYSTEM_STORAGE = "osdu-data-ecosystem-storage"
......@@ -54,7 +55,7 @@ async def http_search_error_handler(request: Request, exc: OSDUSearchException)
"""
Catches and handles Exceptions raised by os-python-client
"""
get_ctx().logger.exception(f"http_search_error_handler - url: '{request.url}'")
get_logger().exception(f"http_search_error_handler - url: '{request.url}'")
if isinstance(exc, OSDUSearchUnexpectedResponse):
status = exc.status_code
errors = [load_content(exc.content)]
......@@ -75,7 +76,7 @@ async def http_storage_error_handler(request: Request, exc: OSDUStorageException
"""
Catches and handles Exceptions raised by os-python-client
"""
get_ctx().logger.exception(f"http_storage_error_handler - url: '{request.url}'")
get_logger().exception(f"http_storage_error_handler - url: '{request.url}'")
if isinstance(exc, OSDUStorageUnexpectedResponse) or isinstance(exc, OSDUStorageResponseValidationError):
status = exc.status_code
errors = [load_content(exc.content)]
......@@ -93,7 +94,7 @@ async def http_partition_error_handler(request: Request, exc: OSDUPartitionExcep
"""
Catches and handles Exceptions raised by os-python-client
"""
get_ctx().logger.exception(f"http_partition_error_handler - url: '{request.url}'")
get_logger().exception(f"http_partition_error_handler - url: '{request.url}'")
return JSONResponse({"origin": OSDU_DATA_ECOSYSTEM_PARTITION, "errors": [exc.message]},
status_code=exc.status_code)
......@@ -21,7 +21,6 @@ import structlog
from structlog.contextvars import merge_contextvars
from opencensus.trace import config_integration
from app.conf import Config
from app.context import get_or_create_ctx
from app.helper.utils import rename_cloud_role_func
......@@ -87,12 +86,16 @@ class AzureContextLoggerAdapter(logging.LoggerAdapter):
return msg, kwargs
def init_logger(service_name):
def init_logger(*, service_name, config):
global _LOGGER
if Config.cloud_provider.value == 'az':
_LOGGER = create_azure_logger(service_name)
elif Config.cloud_provider.value == 'gcp':
if config.cloud_provider.value == 'az':
_LOGGER = create_azure_logger(
service_name=service_name,
az_ai_instrumentation_key=config.get('az_ai_instrumentation_key'),
az_logger_level=config.get('az_logger_level')
)
elif config.cloud_provider.value == 'gcp':
_LOGGER = create_gcp_logger(service_name)
else:
logging.basicConfig(format='%(levelname)s:%(message)s', level=logging.DEBUG)
......@@ -101,8 +104,7 @@ def init_logger(service_name):
return _LOGGER
def create_azure_logger(service_name):
from opencensus.ext.azure.log_exporter import AzureLogHandler
def create_azure_logger(*, service_name, az_ai_instrumentation_key, az_logger_level):
"""
Create logger with two handlers:
- AzureLogHandler: to see Dependencies, Requests, Traces and Exceptions in Azure monitoring
......@@ -117,8 +119,8 @@ def create_azure_logger(service_name):
stdout_handler = logging.StreamHandler(sys.stdout)
# AzurelogHandler for logging to azure appinsight
key = Config.get('az_ai_instrumentation_key')
logger_level = Config.get('az_logger_level')
key = az_ai_instrumentation_key
logger_level = az_logger_level
az_handler = AzureLogHandler(connection_string=f'InstrumentationKey={key}')
az_handler.setLevel(logging.getLevelName(logger_level))
az_handler.add_telemetry_processor(rename_cloud_role_func(service_name))
......
......@@ -17,15 +17,13 @@ from typing import Callable
from fastapi.routing import APIRoute
from opencensus.common.transports.async_ import AsyncTransport
from opencensus.trace import base_exporter
from opencensus.trace import base_exporter, execution_context
from opencensus.trace.propagation.trace_context_http_header_format import TraceContextPropagator
from opencensus.trace.span import SpanKind
from starlette.requests import Request
from starlette.responses import Response
from app.conf import Config
from app.helper.utils import rename_cloud_role_func, azure_traces_processing, COMPONENT
from app.context import get_or_create_ctx
from .utils import rename_cloud_role_func, azure_traces_processing, COMPONENT
"""
......@@ -76,22 +74,22 @@ def _create_gcp_exporter():
return StackdriverExporter(transport=AsyncTransport)
def create_exporter(service_name):
def create_exporter(*, service_name, config):
"""
Create exporters to send traces to different tracing platforms, e.g. Stackdriver (Google) or Azure
cf. documentation https://opencensus.io/exporters/supported-exporters/python/
"""
combined_exporter = CombinedExporter(service_name=service_name)
if Config.cloud_provider.value == 'gcp':
if config.cloud_provider.value == 'gcp':
print("Registering OpenCensus trace Stackdriver")
stackdriver_exporter = _create_gcp_exporter()
combined_exporter.add_exporter(stackdriver_exporter)
elif Config.cloud_provider.value == 'az':
elif config.cloud_provider.value == 'az':
print("Registering OpenCensus trace AzureExporter")
key = Config.get('az_ai_instrumentation_key')
key = config.get('az_ai_instrumentation_key')
try:
az_exporter = _create_azure_exporter(key)
az_exporter.add_telemetry_processor(rename_cloud_role_func(service_name))
......@@ -135,7 +133,7 @@ def with_trace(label: str, span_kind=SpanKind.CLIENT):
@wraps(target)
async def async_inner(*args, **kwargs):
tracer = get_or_create_ctx().tracer
tracer = execution_context.get_opencensus_tracer()
if tracer is None:
return await target(*args, **kwargs)
......@@ -147,7 +145,7 @@ def with_trace(label: str, span_kind=SpanKind.CLIENT):
@wraps(target)
def sync_inner(*args, **kwargs):
tracer = get_or_create_ctx().tracer
tracer = execution_context.get_opencensus_tracer()
if tracer is None:
return target(*args, **kwargs)
......
......@@ -13,7 +13,7 @@
# limitations under the License.
import uuid
from typing import Optional
from fastapi import Depends, Header
from fastapi.security.api_key import APIKeyHeader
from starlette.middleware.base import BaseHTTPMiddleware
......@@ -28,8 +28,9 @@ from app.helper.logger import get_logger
class CreateBasicContextMiddleware(BaseHTTPMiddleware):
def __init__(self, injector: AppInjector, **kwargs):
def __init__(self, *, config: conf.ConfigurationContainer, injector: Optional[AppInjector], **kwargs):
super().__init__(**kwargs)
self._config = config
self._app_injector = injector
@staticmethod
......@@ -58,9 +59,9 @@ class CreateBasicContextMiddleware(BaseHTTPMiddleware):
api_key=api_key)
Context.clear_current()
ctx = Context(logger=get_logger(),
correlation_id=correlation_id,
ctx = Context(correlation_id=correlation_id,
request_id=request_id,
dev_mode=self._config.get('dev_mode'),
partition_id=partition_id,
app_key=app_key,
api_key=api_key,
......
......@@ -26,6 +26,7 @@ from opencensus.trace.span import SpanKind
from app.helper import traces, utils
from app.context import get_or_create_ctx
from app import conf
from app.helper.logger import get_logger
class TracingMiddleware(BaseHTTPMiddleware):
......@@ -121,17 +122,17 @@ class TracingMiddleware(BaseHTTPMiddleware):
ctx.set_current_with_value(tracer=tracer)
self._before_request(request, tracer)
ctx.logger.debug(f'Request start: {request.method} {request.url}')
get_logger().debug(f'Request start: {request.method} {request.url}')
response = None
try:
response = await call_next(request)
return response
except Exception:
ctx.logger.exception(f"Exception occurred when calling: {request.url.path}")
get_logger().exception(f"Exception occurred when calling: {request.url.path}")
raise
finally:
status = response.status_code if response else HTTP_500_INTERNAL_SERVER_ERROR
if not request.url.path.endswith('healthz'):
ctx.logger.info(utils.process_message(request, status))
get_logger().info(utils.process_message(request, status))
self._after_request(request, response, tracer)
......@@ -25,8 +25,8 @@ from osdu.core.api.storage.tenant import Tenant
from osdu.core.api.storage.exceptions import PreconditionFailedException, ResourceNotFoundException
from app.helper.traces import with_trace
from app.context import Context
from app.utils import capture_timings
from app.helper.logger import get_logger
class SessionState(str, Enum):
......@@ -260,14 +260,14 @@ class SessionsStorage:
internal = await self._get_session(tenant, record_id, session_id)
if not internal.session.is_closed:
Context.current().logger.error(f"Invalid state for session deletion: {internal.session}")
get_logger().error(f"Invalid state for session deletion: {internal.session}")
if not force_delete:
raise RuntimeError("Session cannot be deleted. "
"Invalid state. The session must be completed or abandoned before")
object_name = self._build_session_complete_name(record_id, session_id)
await self._storage.delete(tenant, object_name)
Context.current().logger.debug(f'session deleted: {internal.session}')
get_logger().debug(f'session deleted: {internal.session}')
class CompletionContextManager:
def __init__(self, client: 'SessionsStorage', tenant: Tenant, record_id: str, session_id: str, commit: bool):
......@@ -373,7 +373,7 @@ class SessionsStorage:
f"session cannot be {SessionState.Committing.value}")
# let's continue and finish the session
Context.current().logger.warning(
get_logger().warning(
f"session {i_session.session.id} for record {i_session.session.recordId} "
f"appears idle in state {i_session.session.state} since {i_session.session.updatedTime}."
f" State update allowed, will be {new_state}"
......
......@@ -23,6 +23,7 @@ from app.model.log_bulk import LogBulkHelper
from app.bulk_persistence.dask.traces import trace_dataframe_attributes
from app.helper.traces import with_trace
from app.helper.logger import get_logger
class Persistence:
......@@ -50,5 +51,5 @@ class Persistence:
try:
return await create_and_store_dataframe(ctx, dataframe)
except Exception:
ctx.logger.exception("write_bulk")
get_logger().exception("write_bulk")
raise
......@@ -24,6 +24,7 @@ from app.routers.search import search_wrapper
from app.clients import SearchServiceClient, StorageRecordServiceClient
from app.model.entity_utils import Entity, format_kind, get_kind_meta
from app.context import Context
from app.helper.logger import get_logger
class StorageHelper:
......@@ -83,7 +84,7 @@ class StorageHelper:
# first delete the source entity; if it fails, we must not delete the others
await storage_service.delete_record(id=entity_id, data_partition_id=data_partition_id)
ctx.logger.debug(f'record {entity_id} successfully deleted')
get_logger().debug(f'record {entity_id} successfully deleted')
# execute all deletion concurrently, do not stop at first fail
delete_results = await asyncio.gather(*[
......@@ -102,12 +103,12 @@ class StorageHelper:
# log successfully deleted entities for debugging purposes
for r in filter(lambda r: r.status_code == status.HTTP_200_OK, results):
ctx.logger.debug(f'{r.entity["id"]} of kind {r.entity["kind"]} '
get_logger().debug(f'{r.entity["id"]} of kind {r.entity["kind"]} '
f'successfully deleted (from recursive delete of {entity_id})')
# warn for already deleted entity
for r in filter(lambda r: r.status_code == status.HTTP_404_NOT_FOUND, results):
ctx.logger.warning(f'entity {r.entity["id"]} of kind {r.entity["kind"]} was already deleted')
get_logger().warning(f'entity {r.entity["id"]} of kind {r.entity["kind"]} was already deleted')
# errors treatment (i.e. not 200, not 404), gather them by status
in_errors = list(filter(
......@@ -117,7 +118,7 @@ class StorageHelper:
# log errors
for r in in_errors:
ctx.logger.error(f'error on deleted entity {r.entity["id"]} of kind {r.entity["kind"]},'
get_logger().error(f'error on deleted entity {r.entity["id"]} of kind {r.entity["kind"]},'
f'status code: {r.status_code}, detail: {str(r.result)}')
if len(in_errors) == 1: # a single error, just forward
......