Commit ca03e3a1 authored by Jeremie Hallal

Improve dask init

parent 0579a0f3
Related merge requests: !202 and !201 (exclusiveMinimum and maximum does not make sense with ValueWithUnit object, it...), !191 (Improve dask init)
@@ -12,35 +12,31 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import hashlib
import json
import time
from contextlib import suppress
from functools import wraps
from operator import attrgetter
import fsspec
import pandas as pd
from pyarrow.lib import ArrowException
import dask
import dask.dataframe as dd
from dask.distributed import Client as DaskDistributedClient, WorkerPlugin
from osdu.core.api.storage.dask_storage_parameters import DaskStorageParameters
from app.bulk_persistence import BulkId
from app.bulk_persistence.dask.traces import wrap_trace_process
from app.bulk_persistence.dask.errors import BulkNotFound, BulkNotProcessable
from app.bulk_persistence.dask.traces import wrap_trace_process
from app.bulk_persistence.dask.utils import (SessionFileMeta, by_pairs,
do_merge, set_index,
worker_capture_timing_handlers)
from app.helper.logger import get_logger
from app.helper.traces import with_trace
from app.persistence.sessions_storage import Session
from app.utils import capture_timings, get_wdms_temp_dir, get_ctx
from app.utils import DaskClient, capture_timings, get_ctx
from osdu.core.api.storage.dask_storage_parameters import DaskStorageParameters
from pyarrow.lib import ArrowException
dask.config.set({'temporary_directory': get_wdms_temp_dir()})
import dask.dataframe as dd
from dask.distributed import Client as DaskDistributedClient
from dask.distributed import WorkerPlugin
def handle_pyarrow_exceptions(target):
@@ -63,8 +59,8 @@ class DefaultWorkerPlugin(WorkerPlugin):
_LOGGER = logger
self._register_fsspec_implementation = register_fsspec_implementation
get_logger().debug("WorkerPlugin initialised")
super().__init__()
get_logger().debug("WorkerPlugin initialised")
def setup(self, worker):
self.worker = worker
@@ -77,13 +73,14 @@ class DefaultWorkerPlugin(WorkerPlugin):
get_logger().exception(f"Task '{key}' has failed with exception")
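# Presumably defined at module level so dask can pickle it and ship it to workers;
# submitting a bound method would also serialize the whole DaskBulkStorage instance into the task graph.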
def pandas_to_parquet(pdf, path, opt):
return pdf.to_parquet(path, index=True, engine='pyarrow', storage_options=opt)
class DaskBulkStorage:
client = None
client: DaskDistributedClient = None
""" Dask client """
lock_client = asyncio.Lock()
""" used to ensure """
def __init__(self):
""" use `create` to create instance """
self._parameters = None
@@ -95,20 +92,20 @@ class DaskBulkStorage:
instance._parameters = parameters
# Initialise the dask client.
async with DaskBulkStorage.lock_client:
if not DaskBulkStorage.client:
DaskBulkStorage.client = dask_client or await DaskDistributedClient(asynchronous=True, processes=True)
dask_client = dask_client or await DaskClient.create()
if DaskBulkStorage.client is not dask_client: # executed only once per dask client
DaskBulkStorage.client = dask_client
if parameters.register_fsspec_implementation:
parameters.register_fsspec_implementation()
if parameters.register_fsspec_implementation:
parameters.register_fsspec_implementation()
await DaskBulkStorage.client.register_worker_plugin(
DefaultWorkerPlugin,
name="LoggerWorkerPlugin",
logger=get_logger(),
register_fsspec_implementation=parameters.register_fsspec_implementation)
await DaskBulkStorage.client.register_worker_plugin(
DefaultWorkerPlugin,
name="LoggerWorkerPlugin",
logger=get_logger(),
register_fsspec_implementation=parameters.register_fsspec_implementation)
get_logger().info(f"Distributed Dask client initialized : {DaskBulkStorage.client}")
get_logger().info(f"Distributed Dask client initialized : {DaskBulkStorage.client}")
instance._fs = fsspec.filesystem(parameters.protocol, **parameters.storage_options)
return instance
@@ -121,12 +118,6 @@ class DaskBulkStorage:
def base_directory(self) -> str:
return self._parameters.base_directory
@staticmethod
async def close():  # TODO: check whether this is still needed; currently unused
async with DaskBulkStorage.lock_client:
if DaskBulkStorage.client:
await DaskBulkStorage.client.close() # or shutdown
DaskBulkStorage.client = None
def _encode_record_id(self, record_id: str) -> str:
return hashlib.sha1(record_id.encode()).hexdigest()
@@ -149,7 +140,6 @@ class DaskBulkStorage:
path : string or list
**kwargs: dict (of dicts) Passthrough key-word arguments for read backend.
"""
get_logger().debug(f"loading bulk : {path}")
return self._submit_with_trace(dd.read_parquet, path,
engine='pyarrow-dataset',
storage_options=self._parameters.storage_options,
@@ -199,15 +189,17 @@
engine='pyarrow',
storage_options=self._parameters.storage_options)
def _save_with_pandas(self, path, pdf: dd.DataFrame):
async def _save_with_pandas(self, path, pdf: dd.DataFrame):
"""Save the dataframe to a parquet file(s).
pdf: pd.DataFrame or Future<pd.DataFrame>
returns a Future<None>
"""
return self._submit_with_trace(pdf.to_parquet, path,
engine='pyarrow',
storage_options=self._parameters.storage_options)
f_pdf = await self.client.scatter(pdf)
return await self._submit_with_trace(pandas_to_parquet, f_pdf, path,
self._parameters.storage_options)
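For context, a minimal standalone sketch of the scatter-then-submit pattern used here, with a plain dask.distributed client; the dataframe and the file path are made up for illustration, and pyarrow is assumed to be installed:

import pandas as pd
from dask.distributed import Client

def to_parquet_task(pdf: pd.DataFrame, path: str) -> None:
    # runs on a worker; the scattered dataframe is resolved to real data there
    pdf.to_parquet(path, index=True, engine='pyarrow')

client = Client(processes=False)            # small local cluster, illustration only
df = pd.DataFrame({'a': range(10)})
future_df = client.scatter(df)              # ship the data to the cluster once
client.submit(to_parquet_task, future_df, '/tmp/example.parquet').result()
client.close()

Scattering first keeps the task graph small: only a future handle travels with the submitted task, not the serialized dataframe itself.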
def _check_incoming_chunk(self, df):
# TODO: should we check is_monotonic? unique?
if len(df.index) == 0:
@@ -223,13 +215,12 @@
@capture_timings('save_blob', handlers=worker_capture_timing_handlers)
async def save_blob(self, ddf: dd.DataFrame, record_id: str, bulk_id: str = None):
"""Write the data frame to the blob storage."""
# TODO: The new bulk_id should contain information about the way we store the bulk
# In the future, if we change the way we store chunks, it could be useful to deduce it from the bulk_uri
bulk_id = bulk_id or BulkId.new_bulk_id()
if isinstance(ddf, pd.DataFrame):
self._check_incoming_chunk(ddf)
ddf = dd.from_pandas(ddf, npartitions=1)
ddf = await self.client.scatter(ddf)
path = self._get_blob_path(record_id, bulk_id)
try:
@@ -261,14 +252,9 @@ class DaskBulkStorage:
with self._fs.open(f'{session_path_wo_protocol}/{filename}.meta', 'w') as outfile:
json.dump({"columns": list(pdf)}, outfile)
# could be done asynchronously in the workers but it has a cost
# we may want to make it async if the dataFrame is big
session_path = self._build_path_from_session(session)
# await self._save_with_pandas(f'{session_path}/{filename}.parquet', pdf)
# TODO: warning, this is a synchronous CPU-bound operation
pdf.to_parquet(f'{session_path}/{filename}.parquet', index=True,
storage_options=self._parameters.storage_options, engine='pyarrow')
await self._save_with_pandas(f'{session_path}/{filename}.parquet', pdf)
@capture_timings('get_session_parquet_files')
@with_trace('get_session_parquet_files')
@@ -313,7 +299,9 @@ class DaskBulkStorage:
if not dfs:
raise BulkNotProcessable("No data to commit")
dfs = self._map_with_trace(set_index, dfs)
if len(dfs) > 1: # set_index is not needed if no merge operations are done
dfs = self._map_with_trace(set_index, dfs)
while len(dfs) > 1:
dfs = [self._submit_with_trace(do_merge, a, b) for a, b in by_pairs(dfs)]
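As context for the loop above, a self-contained sketch of the same pairwise (tree) reduction shape; `pairwise` and `combine` are illustrative stand-ins, not the project's `by_pairs`/`do_merge` helpers, whose exact semantics may differ:

from typing import Iterable, List, Tuple

def pairwise(items: List) -> Iterable[Tuple]:
    it = iter(items)
    return zip(it, it)  # drops a trailing odd item; handled below

def combine(a: List[int], b: List[int]) -> List[int]:
    return sorted(a + b)

chunks = [[3, 1], [2, 5], [4, 0, 6]]
while len(chunks) > 1:
    merged = [combine(a, b) for a, b in pairwise(chunks)]
    if len(chunks) % 2:          # carry over an unpaired trailing chunk
        merged.append(chunks[-1])
    chunks = merged
print(chunks[0])  # [0, 1, 2, 3, 4, 5, 6]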
@@ -74,7 +74,8 @@ class SessionFileMeta:
def set_index(ddf): # TODO
"""Set index of the dask dataFrame only if needed."""
if not ddf.known_divisions or '_idx' not in ddf:
ddf['_idx'] = ddf.index # we need to create a temporary variable to set it as index
if '_idx' not in ddf:
ddf['_idx'] = ddf.index # we need to create a temporary variable to set it as index
ddf['_idx'] = ddf['_idx'].astype(ddf.index.dtype)
return ddf.set_index('_idx', sorted=True)
return ddf
@@ -14,10 +14,12 @@
import uuid
from asyncio import gather, iscoroutinefunction
from typing import List
from app.model import model_utils
from fastapi import FastAPI, HTTPException, status
from odes_storage.models import *
from fastapi import HTTPException, status
from odes_storage.models import (CreateUpdateRecordsResponse, Record,
RecordVersions)
from osdu.core.api.storage.blob_storage_base import BlobStorageBase
from osdu.core.api.storage.exceptions import ResourceNotFoundException
from osdu.core.api.storage.tenant import Tenant
@@ -116,11 +118,11 @@ class StorageRecordServiceBlobStorage:
skipped_record_ids=[])
async def get_record_version(self,
id: str,
version: int,
data_partition_id: str = None,
appkey: str = None,
token: str = None) -> Record:
id: str,
version: int,
data_partition_id: str = None,
appkey: str = None,
token: str = None) -> Record:
await self._check_auth(appkey, token)
try:
object_name = await self._build_record_path(id, data_partition_id, version=version)
import re
from app.converter.converter_utils import ConverterUtils
from typing import Tuple
OSDU_WELLBORE_VERSION_REGEX = re.compile(r'^([\w\-\.]+:master-data\-\-Wellbore:[\w\-\.\:\%]+):([0-9]*)$')
OSDU_WELLBORE_REGEX = re.compile(r'^[\w\-\.]+:master-data\-\-Wellbore:[\w\-\.\:\%]+$')
@@ -18,7 +19,7 @@ class DMSV3RouterUtils:
return OSDU_WELL_REGEX.match(entity_id) is not None
@staticmethod
def is_osdu_versionned_entity_id(entity_regexp, entity_id: str) -> (bool, str, str):
def is_osdu_versionned_entity_id(entity_regexp, entity_id: str) -> Tuple[bool, str, str]:
"""
:param entity_regexp: regexp to test the entity (one regexp per entity)
:param entity_id: id of the entity to test
@@ -32,11 +33,11 @@ class DMSV3RouterUtils:
return True, matches.group(1), matches.group(2)
@staticmethod
def is_osdu_versionned_wellbore_id(entity_id: str) -> (bool, str, str):
def is_osdu_versionned_wellbore_id(entity_id: str) -> Tuple[bool, str, str]:
return DMSV3RouterUtils.is_osdu_versionned_entity_id(OSDU_WELLBORE_VERSION_REGEX, entity_id)
@staticmethod
def is_osdu_versionned_well_id(entity_id: str) -> (bool, str, str):
def is_osdu_versionned_well_id(entity_id: str) -> Tuple[bool, str, str]:
return DMSV3RouterUtils.is_osdu_versionned_entity_id(OSDU_WELL_VERSION_REGEX, entity_id)
@staticmethod
@@ -44,7 +45,7 @@ class DMSV3RouterUtils:
return DELFI_REGEX.match(entity_id) is not None
@staticmethod
def is_osdu_entity_fake_id(entity_id: str) -> (bool, str):
def is_osdu_entity_fake_id(entity_id: str) -> Tuple[bool, str]:
try:
delfi_id = ConverterUtils.decode_id(entity_id)
return DMSV3RouterUtils.is_delfi_id(delfi_id), delfi_id
@@ -12,9 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
from typing import Optional, Callable, List, Tuple, Union, NamedTuple
import concurrent.futures
from functools import lru_cache, wraps
from functools import lru_cache, wraps, partial
from aiohttp import ClientSession
import contextvars
from os import path, makedirs
@@ -27,6 +28,8 @@ from app.injector.app_injector import AppInjector
from app.conf import Config
from time import perf_counter, process_time
from logging import INFO
import dask
from dask.distributed import Client as DaskDistributedClient
@lru_cache()
@@ -37,16 +40,54 @@ def get_http_client_session(key: str = 'GLOBAL'):
POOL_EXECUTOR_MAX_WORKER = 4
class DaskClient:
client: DaskDistributedClient = None
""" Dask client """
lock_client: asyncio.Lock = None
""" used to ensure """
@staticmethod
async def create() -> DaskDistributedClient:
if not DaskClient.lock_client:
DaskClient.lock_client = asyncio.Lock()
if not DaskClient.client:
async with DaskClient.lock_client:
if not DaskClient.client:
from app.helper.logger import get_logger
get_logger().info(f"Dask client initialization started...")
DaskClient.client = await DaskDistributedClient(asynchronous=True,
processes=True,
dashboard_address=None,
diagnostics_port=None,
)
get_logger().info(f"Dask client initialized : {DaskClient.client}")
return DaskClient.client
@staticmethod
async def close():
async with DaskClient.lock_client:
if DaskClient.client:
await DaskClient.client.close() # or shutdown
DaskClient.client = None
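A minimal usage sketch of the lazy singleton above (illustrative only, assuming an asyncio entry point such as a test): concurrent callers share the one client, and close() resets the cached instance:

import asyncio
from app.utils import DaskClient

async def main():
    # both calls race for the same lock; the distributed client is created only once
    c1, c2 = await asyncio.gather(DaskClient.create(), DaskClient.create())
    assert c1 is c2
    await DaskClient.close()   # sets DaskClient.client back to None

asyncio.run(main())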
def get_pool_executor():
if get_pool_executor._pool is None:
get_pool_executor._pool = concurrent.futures.ProcessPoolExecutor(POOL_EXECUTOR_MAX_WORKER)
return get_pool_executor._pool
get_pool_executor._pool = None
async def run_in_pool_executor(func, *args, **kwargs):
pool = get_pool_executor()
loop = asyncio.get_running_loop()
func = partial(func, *args, **kwargs)
return await loop.run_in_executor(pool, func=func)
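Similarly, a hedged sketch of offloading a CPU-bound call through run_in_pool_executor; `crunch` is a hypothetical function, not part of the codebase, and must live at module level so the process pool can pickle it:

import asyncio
from app.utils import run_in_pool_executor

def crunch(n: int, power: int = 2) -> int:
    # hypothetical CPU-bound work
    return sum(i ** power for i in range(n))

async def main():
    result = await run_in_pool_executor(crunch, 100_000, power=3)
    print(result)

asyncio.run(main())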
def _setup_temp_dir() -> str:
tmpdir = tempfile.gettempdir()
if not tmpdir.endswith('wdmsosdu'):
@@ -469,3 +510,5 @@ class __OpenApiHandler:
OpenApiHandler = __OpenApiHandler()
dask.config.set({'temporary_directory': get_wdms_temp_dir()})
@@ -54,7 +54,8 @@ from app.utils import (
get_http_client_session,
OpenApiHandler,
get_wdms_temp_dir,
get_pool_executor,
run_in_pool_executor,
DaskClient,
POOL_EXECUTOR_MAX_WORKER)
from app.routers.bulk.utils import update_operation_ids, set_v3_input_dataframe_check, set_legacy_input_dataframe_check
@@ -114,14 +115,16 @@ async def startup_event():
MainInjector().configure(app_injector)
wdms_app.trace_exporter = traces.create_exporter(service_name=service_name)
# it seems that the lock is not created in the same event loop as the requests,
# so we need to wait instead of just firing a task
asyncio.create_task(DaskClient.create())
# init executor pool
logger.get_logger().info("Startup process pool executor")
# force the pool to spawn its worker processes now instead of on first use
pool = get_pool_executor()
loop = asyncio.get_running_loop()
futures = [loop.run_in_executor(pool, executor_startup_task) for _ in range(POOL_EXECUTOR_MAX_WORKER)]
await asyncio.gather(*futures)
for _ in range(POOL_EXECUTOR_MAX_WORKER):
asyncio.create_task(run_in_pool_executor(executor_startup_task))
if Config.alpha_feature_enabled.value:
enable_alpha_feature()
@@ -141,6 +144,7 @@ async def shutdown_event():
await storage_client.api_client.close()
await get_http_client_session().close()
await DaskClient.close()
DDMS_V2_PATH = '/ddms/v2'
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from app.utils import DaskClient
import asyncio
from datetime import datetime, timedelta
from tempfile import TemporaryDirectory
@@ -50,7 +51,7 @@ def event_loop(): # all tests will share the same loop
loop = asyncio.get_event_loop()
yield loop
# teardown
loop.run_until_complete(DaskBulkStorage.close())
loop.run_until_complete(DaskClient.close())
loop.close()
import asyncio
import io
from tempfile import TemporaryDirectory
@@ -20,7 +21,7 @@ from app.clients.storage_service_blob_storage import StorageRecordServiceBlobSto
from app.auth.auth import require_opendes_authorized_user
from app.middleware import require_data_partition_id
from app.helper import traces
from app.utils import Context
from app.utils import Context, DaskClient
from app import conf
from tests.unit.persistence.dask_blob_storage_test import generate_df
@@ -131,6 +132,15 @@ def init_fixtures(nope_logger_fixture, monkeypatch):
yield
@pytest.fixture(scope="module")
def event_loop(): # all tests will share the same loop
loop = asyncio.get_event_loop()
yield loop
# teardown
loop.run_until_complete(DaskClient.close())
loop.close()
@pytest.fixture
def dasked_test_app(init_fixtures):
from app.wdms_app import wdms_app, enable_alpha_feature