Commit 7cc5d7d9 authored by Alexandre Vincent

extract DaskClient static methods to get rid of the unused class

parent 76c72c84
Pipeline #111350 failed with stages in 59 minutes and 52 seconds
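At the call sites, the refactor boils down to replacing the `DaskClient` static methods with module-level functions, re-exported from `app.bulk_persistence` as `dask_client`. A minimal before/after sketch of a call site (illustrative only; the `use_dask` coroutine and the `bulk_config` argument stand in for whatever `BulkPersistenceConfig` the application already builds):

from app.bulk_persistence import BulkPersistenceConfig, dask_client


async def use_dask(bulk_config: BulkPersistenceConfig) -> None:
    # before: client = await DaskClient.create(bulk_config)
    client = await dask_client.create(bulk_config)  # idempotent, guarded by an asyncio.Lock
    try:
        pass  # submit work through `client` here
    finally:
        # before: await DaskClient.close()
        await dask_client.close()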
@@ -26,7 +26,7 @@ from .json_orient import JSONOrient
 from .mime_types import MimeTypes, MimeType
 from .exceptions import UnknownChannelsException, InvalidBulkException, NoBulkException, NoDataException, RecordNotFoundException
 from .consistency_checks import ConsistencyException, DataConsistencyChecks
-from .dask.client import DaskClient
+from .dask import client as dask_client
 from .dask.localcluster import DaskException
 from .capture_timings import capture_timings
 from .sessions_storage import Session, SessionsStorage, \
......
@@ -11,6 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from typing import Optional
+
 import contextlib
 import asyncio
@@ -24,80 +26,81 @@ from ..bulk_persistence_config import BulkPersistenceConfig
 from .localcluster import get_dask_configuration
 
-HOUR = 3600  # in seconds
-
-
-class DaskClient:
-    client: DaskDistributedClient = None
-    cluster: LocalCluster = None
-
-    # Ensure access to critical section is done for only one coroutine
-    lock_client: asyncio.Lock = None
-
-    @staticmethod
-    async def create(config: BulkPersistenceConfig) -> DaskDistributedClient:
-        if not DaskClient.lock_client:
-            DaskClient.lock_client = asyncio.Lock()
-
-        if not DaskClient.client:
-            async with DaskClient.lock_client:
-                if not DaskClient.client:
-                    from app.helper.logger import get_logger
-                    logger = get_logger()
-                    logger.info(f"Dask client initialization started...")
-
-                    n_workers, threads_per_worker, worker_memory_limit = get_dask_configuration(config=config, logger=logger)
-                    logger.info(f"Dask client worker configuration: {n_workers} workers running with "
-                                f"{format_bytes(worker_memory_limit)} of RAM and {threads_per_worker} threads each")
-
-                    # Ensure memory used by workers is freed regularly despite memory leak
-                    dask.config.set({'distributed.worker.lifetime.duration': HOUR * 24})
-                    dask.config.set({'distributed.worker.lifetime.stagger': HOUR * 1})
-                    dask.config.set({'distributed.worker.lifetime.restart': True})
-                    logger.info(f"Dask cluster configuration - "
-                                f"worker lifetime: {dask.config.get('distributed.worker.lifetime.duration')}s. "
-                                f"stagger: {dask.config.get('distributed.worker.lifetime.stagger')}s.")
-
-                    DaskClient.cluster = await LocalCluster(
-                        asynchronous=True,
-                        processes=True,
-                        threads_per_worker=threads_per_worker,
-                        n_workers=n_workers,
-                        memory_limit=worker_memory_limit,
-                        dashboard_address=None
-                    )
-
-                    # A worker could be killed when executing a task if lifetime duration elapsed,
-                    # "cluster.adapt(min=N, max=N)" ensure the respawn of workers if it happens
-                    DaskClient.cluster.adapt(minimum=n_workers, maximum=n_workers)
-                    DaskClient.client = await DaskDistributedClient(DaskClient.cluster, asynchronous=True)
-                    get_logger().info(f"Dask client initialized : {DaskClient.client}")
-        return DaskClient.client
-
-    @staticmethod
-    async def close():
-        if not DaskClient.lock_client:
-            return
-
-        async with DaskClient.lock_client:
-            if DaskClient.cluster:
-                # explicitly closing the cluster is necessary
-                # since it has been started independently from the client
-                await DaskClient.cluster.close()
-                DaskClient.cluster = None
-
-            if DaskClient.client:
-                await DaskClient.client.close()  # or shutdown
-                DaskClient.client = None
+_HOUR = 3600  # in seconds
+
+_client: Optional[DaskDistributedClient] = None
+_cluster: Optional[LocalCluster] = None
+
+# Ensure access to critical section is done for only one coroutine
+_lock_client: Optional[asyncio.Lock] = None
+
+
+async def create(config: BulkPersistenceConfig) -> DaskDistributedClient:
+    global _lock_client, _client, _cluster
+
+    if not _lock_client:
+        _lock_client = asyncio.Lock()
+
+    async with _lock_client:
+        if not _client:
+            from app.helper.logger import get_logger
+            logger = get_logger()
+            logger.info(f"Dask client initialization started...")
+
+            n_workers, threads_per_worker, worker_memory_limit = get_dask_configuration(config=config, logger=logger)
+            logger.info(f"Dask client worker configuration: {n_workers} workers running with "
+                        f"{format_bytes(worker_memory_limit)} of RAM and {threads_per_worker} threads each")
+
+            # Ensure memory used by workers is freed regularly despite memory leak
+            dask.config.set({'distributed.worker.lifetime.duration': _HOUR * 24})
+            dask.config.set({'distributed.worker.lifetime.stagger': _HOUR * 1})
+            dask.config.set({'distributed.worker.lifetime.restart': True})
+            logger.info(f"Dask cluster configuration - "
+                        f"worker lifetime: {dask.config.get('distributed.worker.lifetime.duration')}s. "
+                        f"stagger: {dask.config.get('distributed.worker.lifetime.stagger')}s.")
+
+            _cluster = await LocalCluster(
+                asynchronous=True,
+                processes=True,
+                threads_per_worker=threads_per_worker,
+                n_workers=n_workers,
+                memory_limit=worker_memory_limit,
+                dashboard_address=None
+            )
+
+            # A worker could be killed when executing a task if lifetime duration elapsed,
+            # "cluster.adapt(min=N, max=N)" ensure the respawn of workers if it happens
+            _cluster.adapt(minimum=n_workers, maximum=n_workers)
+            _client = await DaskDistributedClient(_cluster, asynchronous=True)
+            get_logger().info(f"Dask client initialized : {_client}")
+
+    return _client
+
+
+async def close():
+    global _cluster, _client
+
+    if not _lock_client:
+        return
+
+    async with _lock_client:
+        if _cluster:
+            # explicitly closing the cluster is necessary
+            # since it has been started independently from the client
+            await _cluster.close()
+            _cluster = None
+
+        if _client:
+            await _client.close()  # or shutdown
+            _client = None
 
 
 @contextlib.asynccontextmanager
 async def actx(config: BulkPersistenceConfig):
     try:
-        client = await DaskClient.create(config)
+        client = await create(config)
         yield client
     finally:
-        await DaskClient.close()
+        await close()
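For scoped usage the module keeps the `actx` async context manager shown at the end of the hunk; a minimal sketch of driving it through the new module-level API (the `run_once` coroutine and the toy workload are illustrative only, not part of this diff):

from app.bulk_persistence import BulkPersistenceConfig, dask_client


async def run_once(config: BulkPersistenceConfig) -> None:
    # actx() awaits create() on entry and close() on exit, even if the body raises
    async with dask_client.actx(config) as client:
        # with asynchronous=True, dask.distributed futures are awaitable
        future = client.submit(sum, [1, 2, 3])
        print(await future)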
@@ -31,7 +31,7 @@ from ..capture_timings import capture_timings
 from ..sessions_storage import Session
 from ..bulk_persistence_config import BulkPersistenceConfig
 
-from .client import DaskClient
+from .client import create as dask_client_create
 from .dask_worker_plugin import DaskWorkerPlugin
 from .errors import BulkRecordNotFound, BulkNotProcessable, internal_bulk_exceptions
 from .traces import map_with_trace, submit_with_trace, trace_attributes_root_span, trace_attributes_current_span
@@ -111,7 +111,7 @@ class DaskBulkStorage:
         instance._parameters = parameters
 
         # Initialise the dask client.
-        dask_client = dask_client or await DaskClient.create(config)
+        dask_client = dask_client or await dask_client_create(config)
        if DaskBulkStorage.client is not dask_client:  # executed only once per dask client
            DaskBulkStorage.client = dask_client
......
@@ -58,7 +58,7 @@ from app.utils import (
     get_http_client_session,
     OpenApiHandler,
     POOL_EXECUTOR_MAX_WORKER)
-from app.bulk_persistence import DaskClient, BulkPersistenceConfig, set_config_getter
+from app.bulk_persistence import dask_client, BulkPersistenceConfig, set_config_getter
 from app.routers.bulk.utils import (
     update_operation_ids,
     set_v3_input_dataframe_check,
@@ -148,7 +148,7 @@ async def startup_event():
     # seems that the lock is not in the same event loop as requests
     # so we need to wait instead of just fire a task
-    asyncio.create_task(DaskClient.create(bulk_config))
+    asyncio.create_task(dask_client.create(bulk_config))
 
     create_custom_http_exception_handler(wdms_app, logger)
     # init executor pool
     logger.get_logger().info("Startup process pool executor")
@@ -173,7 +173,7 @@ async def shutdown_event():
     await search_client.api_client.close()
     await get_http_client_session().close()
-    await DaskClient.close()
+    await dask_client.close()
 
 
 DDMS_V2_PATH = '/ddms/v2'
......
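The wdms_app changes above follow the usual ASGI lifecycle-hook pattern; a stripped-down sketch of the wiring, assuming a FastAPI app (the real startup/shutdown handlers do considerably more, and `bulk_config` here is a placeholder for the configuration object built elsewhere in the application):

import asyncio

from fastapi import FastAPI

from app.bulk_persistence import dask_client

app = FastAPI()
bulk_config = ...  # placeholder: the BulkPersistenceConfig built by the existing config code


@app.on_event("startup")
async def startup_event():
    # spin the cluster up in the background; create() is idempotent,
    # so later callers simply get the already-initialized client
    asyncio.create_task(dask_client.create(bulk_config))


@app.on_event("shutdown")
async def shutdown_event():
    # the cluster is closed explicitly before the client, as in close()
    await dask_client.close()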
@@ -7,7 +7,7 @@ import types
 import pytest
 from unittest import mock
 
-import app.bulk_persistence.dask.client as dask_client
+from app.bulk_persistence import dask_client
 
 
 @pytest.fixture
@@ -16,6 +16,7 @@ def local_cluster_mock():
     # we need to mock both the class and the instance given current implementation
     mock_cluster_instance = mock.AsyncMock()
     mock_cluster_instance.close = mock.AsyncMock()
+    mock_cluster_instance.adapt = mock.Mock()
     with mock.patch("app.bulk_persistence.dask.client.LocalCluster",
                     new=mock.AsyncMock(return_value=mock_cluster_instance)) as mock_cluster:
         yield mock_cluster, mock_cluster_instance
@@ -46,12 +47,12 @@ async def test_dask_client_create_close_idempotent_sync(nope_logger_fixture, loc
     client_mock, client_instance_mock = dask_distributed_client_mock
 
     # first call
-    client = await dask_client.DaskClient.create(local_bulk_persistence_config)
+    client = await dask_client.create(local_bulk_persistence_config)
     assert_await_called_once(cluster_mock)
     assert_await_called_once(client_mock)
 
     # another call
-    same_client = await dask_client.DaskClient.create(local_bulk_persistence_config)
+    same_client = await dask_client.create(local_bulk_persistence_config)
 
     # init has NOT been called again
     assert_await_called_once(cluster_mock)
@@ -61,12 +62,12 @@ async def test_dask_client_create_close_idempotent_sync(nope_logger_fixture, loc
     assert id(client) == id(same_client)
 
     # close call
-    await dask_client.DaskClient.close()
+    await dask_client.close()
     assert_await_called_once(cluster_instance_mock.close)
     assert_await_called_once(client_instance_mock.close)
 
     # another close call
-    await dask_client.DaskClient.close()
+    await dask_client.close()
 
     # close has NOT been called again
     assert_await_called_once(cluster_instance_mock.close)
@@ -84,14 +85,14 @@ def test_dask_client_create_close_idempotent_async(nope_logger_fixture, local_cl
     # running n start in parallel
     loop.run_until_complete(asyncio.gather(
-        dask_client.DaskClient.create(local_bulk_persistence_config),
-        dask_client.DaskClient.create(local_bulk_persistence_config),
-        dask_client.DaskClient.create(local_bulk_persistence_config),
-        dask_client.DaskClient.create(local_bulk_persistence_config),
-        dask_client.DaskClient.create(local_bulk_persistence_config),
-        dask_client.DaskClient.create(local_bulk_persistence_config),
-        dask_client.DaskClient.create(local_bulk_persistence_config),
-        dask_client.DaskClient.create(local_bulk_persistence_config)
+        dask_client.create(local_bulk_persistence_config),
+        dask_client.create(local_bulk_persistence_config),
+        dask_client.create(local_bulk_persistence_config),
+        dask_client.create(local_bulk_persistence_config),
+        dask_client.create(local_bulk_persistence_config),
+        dask_client.create(local_bulk_persistence_config),
+        dask_client.create(local_bulk_persistence_config),
+        dask_client.create(local_bulk_persistence_config)
     ))
     assert_await_called_once(cluster_mock)
@@ -99,12 +100,12 @@ def test_dask_client_create_close_idempotent_async(nope_logger_fixture, local_cl
     # running m stop in parallel
     loop.run_until_complete(asyncio.gather(
-        dask_client.DaskClient.close(),
-        dask_client.DaskClient.close(),
-        dask_client.DaskClient.close(),
-        dask_client.DaskClient.close(),
-        dask_client.DaskClient.close(),
-        dask_client.DaskClient.close()
+        dask_client.close(),
+        dask_client.close(),
+        dask_client.close(),
+        dask_client.close(),
+        dask_client.close(),
+        dask_client.close()
    ))
    assert_await_called_once(cluster_instance_mock.close)
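These tests lean on an `assert_await_called_once` helper that is defined outside the hunks shown here; a plausible minimal version, purely illustrative of what the assertions above check (the project's real helper may differ):

from unittest import mock


def assert_await_called_once(async_mock: mock.AsyncMock) -> None:
    # AsyncMock tracks awaits separately from plain calls
    async_mock.assert_awaited_once()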
......
 import pytest
 
-from app.bulk_persistence import DaskException, DaskClient
+from app.bulk_persistence import DaskException, dask_client
 from app.bulk_persistence.dask.localcluster import memory_leeway
-import app.bulk_persistence.dask.client as dask_client
 
 
 @pytest.mark.asyncio
......