Commit 3a436429 authored by Luc Yriarte

Merge branch 'master' into bulk-api-v2

parents 746cb37a f5382f0e
Pipeline #48356 passed with stages in 12 minutes and 28 seconds
......@@ -14,7 +14,7 @@
from .bulk_id import BulkId
from .dataframe_persistence import create_and_store_dataframe, get_dataframe
from .dataframe_serializer import DataframeSerializer
from .dataframe_serializer import DataframeSerializerAsync, DataframeSerializerSync
from .json_orient import JSONOrient
from .mime_types import MimeTypes
from .tenant_provider import resolve_tenant
......
......@@ -13,6 +13,7 @@
# limitations under the License.
import json
import asyncio
from io import BytesIO
from typing import Union, AnyStr, IO, Optional, List
......@@ -24,9 +25,11 @@ from pandas import DataFrame as DataframeClass
from .json_orient import JSONOrient
from .mime_types import MimeTypes
from app.utils import get_pool_executor
from ..helper.traces import with_trace
class DataframeSerializer:
class DataframeSerializerSync:
"""
The goal is to encapsulate (de)serialization of dataframes to/from various formats,
and to provide a unified way to handle concerns such as float/double precision, compression, etc.
......@@ -82,7 +85,7 @@ class DataframeSerializer:
return df.fillna("NaN").to_json(path_or_buf=path_or_buf, orient=orient.value)
@classmethod
def read_parquet(cls, data) -> 'DataframeSerializer.DataframeClass':
def read_parquet(cls, data) -> 'DataframeSerializerAsync.DataframeClass':
"""
:param data: bytes, path object or file-like object
:return: dataframe
......@@ -94,7 +97,7 @@ class DataframeSerializer:
return pd.read_parquet(data)
@classmethod
def read_json(cls, data, orient: Union[str, JSONOrient]) -> 'DataframeSerializer.DataframeClass':
def read_json(cls, data, orient: Union[str, JSONOrient]) -> 'DataframeSerializerAsync.DataframeClass':
"""
:param data: bytes str content (valid JSON str), path object or file-like object
:param orient:
......@@ -103,3 +106,29 @@ class DataframeSerializer:
orient = JSONOrient.get(orient)
return pd.read_json(path_or_buf=data, orient=orient.value).replace("NaN", np.NaN)
class DataframeSerializerAsync:
def __init__(self, pool_executor=get_pool_executor()):
self.executor = pool_executor
@with_trace("JSON bulk serialization")
async def to_json(self,
df: DataframeClass,
orient: Union[str, JSONOrient] = JSONOrient.split,
path_or_buf: Optional[Union[str, Path, IO[AnyStr]]] = None) -> Optional[str]:
return await asyncio.get_event_loop().run_in_executor(
self.executor, DataframeSerializerSync.to_json, df, orient, path_or_buf
)
@with_trace("Parquet bulk deserialization")
async def read_parquet(self, data) -> DataframeClass:
return await asyncio.get_event_loop().run_in_executor(
self.executor, DataframeSerializerSync.read_parquet, data
)
@with_trace("Parquet JSON deserialization")
async def read_json(self, data, orient: Union[str, JSONOrient]) -> DataframeClass:
return await asyncio.get_event_loop().run_in_executor(
self.executor, DataframeSerializerSync.read_json, data, orient
)
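
DataframeSerializerAsync simply pushes the blocking pandas calls of DataframeSerializerSync onto the shared executor so the event loop stays responsive. Below is a minimal, self-contained sketch of that offload pattern; the standalone executor and function names are illustrative only, not the service's actual wiring.

import asyncio
import concurrent.futures
import pandas as pd

_executor = concurrent.futures.ThreadPoolExecutor(max_workers=2)  # stand-in for get_pool_executor()

async def to_json_offloaded(df: pd.DataFrame) -> str:
    # pandas serialization is synchronous and CPU-bound; run it off the event loop
    loop = asyncio.get_event_loop()
    return await loop.run_in_executor(_executor, df.to_json)

async def demo():
    df = pd.DataFrame({"MD": [0.0, 0.5, 1.0]})
    print(await to_json_offloaded(df))

if __name__ == "__main__":
    asyncio.run(demo())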
......@@ -15,12 +15,13 @@
import asyncio
from typing import List, Set, Optional
import re
from contextlib import suppress
from fastapi import APIRouter, Depends, HTTPException, Request, status
from fastapi.responses import Response
import pandas as pd
from app.bulk_persistence import DataframeSerializer, JSONOrient
from app.bulk_persistence import DataframeSerializerAsync, JSONOrient
from app.bulk_persistence.bulk_id import BulkId
from app.bulk_persistence.dask.dask_bulk_storage import DaskBulkStorage
from app.bulk_persistence.dask.errors import BulkError, BulkNotFound
......@@ -54,7 +55,7 @@ async def get_df_from_request(request: Request, orient: Optional[str] = None) ->
if MimeTypes.PARQUET.match(ct):
content = await request.body() # request.stream()
try:
return DataframeSerializer.read_parquet(content)
return await DataframeSerializerAsync().read_parquet(content)
except OSError as err:
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail=f'{err}') # TODO
......@@ -62,7 +63,7 @@ async def get_df_from_request(request: Request, orient: Optional[str] = None) ->
if MimeTypes.JSON.match(ct):
content = await request.body() # request.stream()
try:
return DataframeSerializer.read_json(content, orient)
return await DataframeSerializerAsync().read_json(content, orient)
except ValueError:
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail='invalid body') # TODO
......@@ -91,26 +92,33 @@ class DataFrameRender:
driver = await with_dask_blob_storage()
return await driver.client.submit(lambda: len(df.index))
re_1D_curve_selection = re.compile(r'\[(?P<index>[0-9]+)\]$')
re_2D_curve_selection = re.compile(r'\[(?P<range>[0-9]+:[0-9]+)\]$')
re_array_selection = re.compile(r'^(?P<name>.+)\[(?P<start>[^:]+):?(?P<stop>.*)\]$')
@staticmethod
def _col_matching(sel, col):
if sel == col: # exact match
return True
m_col = DataFrameRender.re_array_selection.match(col)
if not m_col: # if the column doesn't have an array pattern (col[*])
return False
# compare selection with curve name without array suffix [*]
if sel == m_col['name']: # if selection is 'c', c[*] should match
return True
# range selection use cases c[0:2] should match c[0], c[1] and c[2]
m_sel = DataFrameRender.re_array_selection.match(sel)
if m_sel and m_sel['stop']:
with suppress(ValueError): # suppress int conversion exceptions
if int(m_sel['start']) <= int(m_col['start']) <= int(m_sel['stop']):
return True
return False
@staticmethod
def get_matching_column(selection: List[str], cols: Set[str]) -> List[str]:
selected = set()
for to_find in selection:
m = DataFrameRender.re_2D_curve_selection.search(to_find)
if m:
r = range(*map(int, m['range'].split(':')))
def is_matching(c):
if c == to_find:
return True
i = DataFrameRender.re_1D_curve_selection.search(c)
return i and int(i['index']) in r
else:
def is_matching(c):
return c == to_find or to_find == DataFrameRender.re_1D_curve_selection.sub('', c)
selected.update(filter(is_matching, cols.difference(selected)))
for sel in selection:
selected.update(filter(lambda col: DataFrameRender._col_matching(sel, col),
cols.difference(selected)))
return list(selected)
@staticmethod
......
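
The two per-dimension regexes are replaced by a single re_array_selection pattern plus _col_matching: a bare curve name matches every array column of that curve, and a range selection is inclusive on both bounds and compares against the first index of the last [..] group. A self-contained sketch of those semantics, restated here with assertions purely as examples:

import re
from contextlib import suppress

ARRAY = re.compile(r'^(?P<name>.+)\[(?P<start>[^:]+):?(?P<stop>.*)\]$')

def matches(sel: str, col: str) -> bool:
    if sel == col:                       # exact match
        return True
    m_col = ARRAY.match(col)
    if not m_col:                        # column has no array suffix [*]
        return False
    if sel == m_col['name']:             # 'c' matches any 'c[i]'
        return True
    m_sel = ARRAY.match(sel)
    if m_sel and m_sel['stop']:          # range selection 'c[a:b]'
        with suppress(ValueError):       # ignore non-integer bounds
            return int(m_sel['start']) <= int(m_col['start']) <= int(m_sel['stop'])
    return False

assert matches('2D', '2D[1]')            # bare name matches any index
assert matches('2D[0:2]', '2D[2]')       # ranges are inclusive on both ends
assert not matches('2D[0:2]', '2D[3]')
assert matches('X[0:1]', 'X[0][1]')      # only the last [..] group is compared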
......@@ -3,14 +3,14 @@ from app.conf import Config
from httpx import (
RemoteProtocolError,
TimeoutException, # => ReadTimeout, WriteTimeout, ConnectTimeout, PoolTimeout
TimeoutException) # => ReadTimeout, WriteTimeout, ConnectTimeout, PoolTimeout
)
from odes_storage.exceptions import ResponseHandlingException
def backoff_policy(on_backoff_handlers=None):
return backoff.on_exception(backoff.expo,
(RemoteProtocolError, TimeoutException),
(RemoteProtocolError, TimeoutException, ResponseHandlingException),
max_tries=Config.de_client_backoff_max_tries.value,
on_backoff=on_backoff_handlers,
base=0.5,
......
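
backoff_policy() is a decorator factory, so adding ResponseHandlingException to the tuple means storage-client failures are retried with the same exponential policy as raw httpx errors wherever the decorator is applied. A hedged usage sketch; fetch_record, storage_client and get_record are illustrative names, not part of the commit:

def log_retry(details):
    # standard backoff handler: details carries target, tries, wait, args, kwargs, ...
    print(f"retry #{details['tries']} of {details['target'].__name__} after {details['wait']:.1f}s")

@backoff_policy(on_backoff_handlers=[log_retry])
async def fetch_record(storage_client, record_id: str):
    # any RemoteProtocolError, TimeoutException or ResponseHandlingException
    # raised here is retried up to the configured max_tries
    return await storage_client.get_record(record_id=record_id)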
......@@ -23,7 +23,7 @@ from opencensus.trace.span import SpanKind
from app.conf import Config
from app.helper.utils import rename_cloud_role_func, COMPONENT
from app.utils import Context
from app.utils import get_or_create_ctx
"""
How to add specific span in a method
......@@ -115,11 +115,11 @@ def with_trace(label: str, span_kind=SpanKind.CLIENT):
@wraps(target)
async def async_inner(*args, **kwargs):
tracer = Context.current().tracer
tracer = get_or_create_ctx().tracer
if tracer is None:
return await target(*args, **kwargs)
with Context.current().tracer.span(name=label) as span:
with tracer.span(name=label) as span:
span.span_kind = span_kind
return await target(*args, **kwargs)
......@@ -127,11 +127,11 @@ def with_trace(label: str, span_kind=SpanKind.CLIENT):
@wraps(target)
def sync_inner(*args, **kwargs):
tracer = Context.current().tracer
tracer = get_or_create_ctx().tracer
if tracer is None:
return target(*args, **kwargs)
with Context.current().tracer.span(name=label) as span:
with tracer.span(name=label) as span:
span.span_kind = span_kind
return target(*args, **kwargs)
......
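
with_trace now resolves the tracer once via get_or_create_ctx() and reuses it; when no tracer is configured the wrapped call simply runs untraced instead of failing. A hedged usage sketch, with an illustrative span label and function:

from opencensus.trace.span import SpanKind

@with_trace("read bulk chunk", span_kind=SpanKind.CLIENT)
async def read_bulk_chunk(storage, chunk_id: str):
    # timed under a CLIENT span named "read bulk chunk" when tracing is enabled,
    # executed as-is otherwise
    return await storage.read(chunk_id)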
......@@ -37,7 +37,7 @@ from odes_storage.models import (
)
from pydantic import BaseModel, Field
from app.bulk_persistence import DataframeSerializer, JSONOrient, MimeTypes, get_dataframe
from app.bulk_persistence import DataframeSerializerAsync, DataframeSerializerSync, JSONOrient, MimeTypes, get_dataframe
from app.clients.storage_service_client import get_storage_record_service
from app.model.log_bulk import LogBulkHelper
from app.model.model_curated import log
......@@ -224,7 +224,7 @@ _log_dataframe_example = pd.DataFrame(
'.\n Here\'re examples for data with {} rows and {} columns with different _orient_: '.format(
_log_dataframe_example.shape[0],
_log_dataframe_example.shape[1]) +
''.join([f'\n* {o.value}: <br/>`{DataframeSerializer.to_json(_log_dataframe_example, o)}`<br/>&nbsp;'
''.join([f'\n* {o.value}: <br/>`{DataframeSerializerSync.to_json(_log_dataframe_example, o)}`<br/>&nbsp;'
for o in JSONOrient]),
# put examples here because of bug in swagger UI to properly render multiple examples
'required': True,
......@@ -233,9 +233,9 @@ _log_dataframe_example = pd.DataFrame(
'schema': {
# swagger UI bug, so single example here
'example': json.loads(
DataframeSerializer.to_json(_log_dataframe_example, JSONOrient.split)
DataframeSerializerSync.to_json(_log_dataframe_example, JSONOrient.split)
),
'oneOf': [DataframeSerializer.get_schema(o) for o in JSONOrient]
'oneOf': [DataframeSerializerSync.get_schema(o) for o in JSONOrient]
}
}
}
......@@ -256,7 +256,7 @@ async def write_log_data(
ctx: Context = Depends(get_ctx),
) -> CreateUpdateRecordsResponse:
content = await request.body() # request.stream()
df = DataframeSerializer.read_json(content, orient)
df = await DataframeSerializerAsync().read_json(content, orient)
return await _write_log_data(ctx, persistence, logid, bulk_path, df)
# ---------------------------------------------------------------------------------------------------------------------
......@@ -292,10 +292,10 @@ async def upload_log_data_file(
if mime_type == MimeTypes.JSON:
# TODO for now the entire content is read at once, can chunk it instead I guess
content: bytes = await file.read()
df = DataframeSerializer.read_json(content, orient)
df = await DataframeSerializerAsync().read_json(content, orient)
elif mime_type == MimeTypes.PARQUET:
try:
df = DataframeSerializer.read_parquet(file.file)
df = await DataframeSerializerAsync().read_parquet(file.file)
except Exception as e:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST,
detail='invalid data: ' + e.message if hasattr(e, 'message') else 'unknown error')
......@@ -335,7 +335,7 @@ async def _get_log_data(
log_record = await fetch_record(ctx, logid, version)
df = await persistence.read_bulk(ctx, log_record, bulk_id_path)
content = DataframeSerializer.to_json(df, orient=orient)
content = await DataframeSerializerAsync().to_json(df, orient=orient)
return Response(content=content, media_type=MimeTypes.JSON.type) # content is already jsonified no need to use JSONResponse
......@@ -351,12 +351,12 @@ async def _get_log_data(
'.\n Here\'re examples for data with {} rows and {} columns with different _orient_: '.format(
_log_dataframe_example.shape[0],
_log_dataframe_example.shape[1]) +
''.join([f'\n* {o.value}: <br/>`{DataframeSerializer.to_json(_log_dataframe_example, o)}`<br/>&nbsp;'
''.join([f'\n* {o.value}: <br/>`{DataframeSerializerSync.to_json(_log_dataframe_example, o)}`<br/>&nbsp;'
for o in JSONOrient]),
name='GetLogDataResponse',
example=DataframeSerializer.to_json(_log_dataframe_example, JSONOrient.split),
schema={'oneOf': [DataframeSerializer.get_schema(o) for o in JSONOrient]})
example=DataframeSerializerSync.to_json(_log_dataframe_example, JSONOrient.split),
schema={'oneOf': [DataframeSerializerSync.get_schema(o) for o in JSONOrient]})
])
@router.get('/logs/{logid}/data',
summary="Returns all data within the specified filters. Strongly consistent.",
......@@ -433,12 +433,12 @@ async def get_log_data_statistics(logid: str,
'.\n Here\'re examples for data with {} rows and {} columns with different _orient_: '.format(
_log_dataframe_example.shape[0],
_log_dataframe_example.shape[1]) +
''.join([f'\n* {o.value}: <br/>`{DataframeSerializer.to_json(_log_dataframe_example, o)}`<br/>&nbsp;'
''.join([f'\n* {o.value}: <br/>`{DataframeSerializerSync.to_json(_log_dataframe_example, o)}`<br/>&nbsp;'
for o in JSONOrient]),
name='GetLogDataResponse',
example=DataframeSerializer.to_json(_log_dataframe_example, JSONOrient.split),
schema={'oneOf': [DataframeSerializer.get_schema(o) for o in JSONOrient]})
example=DataframeSerializerSync.to_json(_log_dataframe_example, JSONOrient.split),
schema={'oneOf': [DataframeSerializerSync.get_schema(o) for o in JSONOrient]})
])
@router.get('/logs/{logid}/versions/{version}/data',
summary="Returns all data within the specified filters. Strongly consistent.",
......
......@@ -29,7 +29,8 @@ from app.model.model_curated import (
from app.routers.common_parameters import json_orient_parameter, REQUIRED_ROLES_READ, REQUIRED_ROLES_WRITE
from app.model.model_utils import from_record, to_record
from app.routers.trajectory.persistence import Persistence
from app.bulk_persistence import (DataframeSerializer,
from app.bulk_persistence import (DataframeSerializerSync,
DataframeSerializerAsync,
JSONOrient,
MimeTypes,
NoBulkException,
......@@ -191,7 +192,7 @@ _trajectory_dataframe_example = DataFrame([
_trajectory_dataframe_example.shape[0],
_trajectory_dataframe_example.shape[1],
', '.join(_trajectory_dataframe_example.columns.tolist())) +
''.join([f'\n* {o}: <br/>`{DataframeSerializer.to_json(_trajectory_dataframe_example, o)}`<br/>&nbsp;'
''.join([f'\n* {o}: <br/>`{DataframeSerializerSync.to_json(_trajectory_dataframe_example, o)}`<br/>&nbsp;'
for o in JSONOrient]),
# put examples here because of bug in swagger UI to properly render multiple examples
"required": True,
......@@ -199,12 +200,12 @@ _trajectory_dataframe_example = DataFrame([
MimeTypes.JSON.type: {
"schema": {
# swagger UI bug, so single example here
"example": DataframeSerializer.to_json(
"example": DataframeSerializerSync.to_json(
_trajectory_dataframe_example,
JSONOrient.split
),
"oneOf": [
DataframeSerializer.get_schema(o) for o in JSONOrient
DataframeSerializerSync.get_schema(o) for o in JSONOrient
],
}
}
......@@ -230,7 +231,7 @@ async def post_traj_data(
persistence: Persistence = Depends(get_persistence)) -> CreateUpdateRecordsResponse:
content = await request.body() # request.stream()
df = DataframeSerializer.read_json(content, orient)
df = await DataframeSerializerAsync().read_json(content, orient)
trajectory_record = await fetch_trajectory_record(ctx, trajectoryid)
record = from_record(Trajectory, trajectory_record)
......@@ -300,7 +301,7 @@ async def _get_trajectory_data(
except InvalidBulkException as ex:
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(ex))
content = DataframeSerializer.to_json(df, orient=orient)
content = await DataframeSerializerAsync().to_json(df, orient=orient)
return Response(content=content, media_type=MimeTypes.JSON.type)
......@@ -316,12 +317,12 @@ async def _get_trajectory_data(
'.\n Here\'re examples for data with {} rows for channels {} with different _orient_: '.format(
_trajectory_dataframe_example.shape[0],
', '.join(_trajectory_dataframe_example.columns.tolist())) +
''.join([f'\n* {o.value}: <br/>`{DataframeSerializer.to_json(_trajectory_dataframe_example, o)}`<br/>&nbsp;'
''.join([f'\n* {o.value}: <br/>`{DataframeSerializerSync.to_json(_trajectory_dataframe_example, o)}`<br/>&nbsp;'
for o in JSONOrient]),
name="GetLogDataResponse",
example=DataframeSerializer.to_json(_trajectory_dataframe_example, JSONOrient.split),
example=DataframeSerializerSync.to_json(_trajectory_dataframe_example, JSONOrient.split),
schema={
"oneOf": [DataframeSerializer.get_schema(o) for o in JSONOrient]
"oneOf": [DataframeSerializerSync.get_schema(o) for o in JSONOrient]
},
)
],
......@@ -367,12 +368,12 @@ async def get_traj_data(
'.\n Here\'re examples for data with {} rows and {} columns with different _orient_: '.format(
_trajectory_dataframe_example.shape[0],
_trajectory_dataframe_example.shape[1]) +
''.join([f'\n* {o.value}: <br/>`{DataframeSerializer.to_json(_trajectory_dataframe_example, o)}`<br/>&nbsp;'
''.join([f'\n* {o.value}: <br/>`{DataframeSerializerSync.to_json(_trajectory_dataframe_example, o)}`<br/>&nbsp;'
for o in JSONOrient]),
name='GetTrajectoryDataResponse',
example=DataframeSerializer.to_json(_trajectory_dataframe_example, JSONOrient.split),
schema={'oneOf': [DataframeSerializer.get_schema(o) for o in JSONOrient]})
example=DataframeSerializerSync.to_json(_trajectory_dataframe_example, JSONOrient.split),
schema={'oneOf': [DataframeSerializerSync.get_schema(o) for o in JSONOrient]})
])
@router.get('/trajectories/{trajectoryid}/versions/{version}/data',
summary="Returns all data within the specified filters. Strongly consistent.",
......
......@@ -34,11 +34,17 @@ def get_http_client_session(key: str = 'GLOBAL'):
return ClientSession(json_serialize=json.dumps)
POOL_EXECUTOR_MAX_WORKER = 4
def get_pool_executor():
if get_pool_executor._pool is None:
get_pool_executor._pool = concurrent.futures.ProcessPoolExecutor(POOL_EXECUTOR_MAX_WORKER)
return get_pool_executor._pool
get_pool_executor._pool = concurrent.futures.ThreadPoolExecutor()
get_pool_executor._pool = None
def _setup_temp_dir() -> str:
......
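
get_pool_executor now memoizes a lazily created ProcessPoolExecutor on a function attribute, so every caller shares one pool sized by POOL_EXECUTOR_MAX_WORKER. A short sketch of that contract:

import concurrent.futures

pool_a = get_pool_executor()
pool_b = get_pool_executor()
assert pool_a is pool_b                                            # one shared pool per process
assert isinstance(pool_a, concurrent.futures.ProcessPoolExecutor)  # worker processes, not threads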
......@@ -12,6 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
from os import getpid
import asyncio
from time import sleep
from fastapi import FastAPI, Depends
from fastapi.openapi.utils import get_openapi
......@@ -48,7 +51,12 @@ from app.routers.dipset import dipset_ddms_v2, dip_ddms_v2
from app.routers.logrecognition import log_recognition
from app.routers.search import search, fast_search, search_v3, fast_search_v3
from app.clients import StorageRecordServiceClient, SearchServiceClient
from app.utils import get_http_client_session, OpenApiHandler, get_wdms_temp_dir
from app.utils import (
get_http_client_session,
OpenApiHandler,
get_wdms_temp_dir,
get_pool_executor,
POOL_EXECUTOR_MAX_WORKER)
base_app = FastAPI()
......@@ -90,6 +98,12 @@ def hide_router_modules(modules):
rte.include_in_schema = False
def executor_startup_task():
""" This is a dummy task used to startup executors"""
print(f"process {getpid()} started")
sleep(0.2) # to keep executor "busy"
@base_app.on_event("startup")
async def startup_event():
service_name = Config.service_name.value
......@@ -100,6 +114,15 @@ async def startup_event():
MainInjector().configure(app_injector)
wdms_app.trace_exporter = traces.create_exporter(service_name=service_name)
# init executor pool
logger.get_logger().info("Startup process pool executor")
# force worker processes to spawn now rather than on first use
pool = get_pool_executor()
loop = asyncio.get_running_loop()
futures = [loop.run_in_executor(pool, executor_startup_task) for _ in range(POOL_EXECUTOR_MAX_WORKER)]
await asyncio.gather(*futures)
if Config.alpha_feature_enabled.value:
enable_alpha_feature()
......
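
The startup hook pre-spawns the worker processes so the first bulk request does not pay the spawn cost. Note that anything submitted to a ProcessPoolExecutor must be picklable, which is why executor_startup_task (and the DataframeSerializerSync classmethods) are module-level callables. A self-contained sketch of the same warm-up pattern under that assumption:

import asyncio
import concurrent.futures

def warmup(i: int) -> int:
    # must be a top-level function: the callable and its arguments are pickled to the workers
    return i * i

async def warm_pool():
    loop = asyncio.get_event_loop()
    with concurrent.futures.ProcessPoolExecutor(max_workers=4) as pool:
        results = await asyncio.gather(*(loop.run_in_executor(pool, warmup, i) for i in range(4)))
        print(results)  # [0, 1, 4, 9]

if __name__ == "__main__":
    asyncio.run(warm_pool())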
......@@ -32,6 +32,6 @@ osdu-data-ecosystem-search>=0.3.2, <0.4
osdu-core-lib-python-ibm>=0.0.1, <0.1
osdu-core-lib-python-gcp~=1.1.0
osdu-core-lib-python-azure~=1.1.1
osdu-core-lib-python-azure~=1.1.2
osdu-core-lib-python-aws>=0.0.1, <0.1
osdu-core-lib-python~=1.1.0
......@@ -116,14 +116,14 @@ def create_session(env, entity_type: EntityType, record_id: str, overwrite: bool
url = build_base_url(entity_type) + f'/{record_id}/sessions'
runner = build_request(f'create {entity_type} session', 'POST', url,
payload={'mode': 'overwrite' if overwrite else 'update'})
return runner.call(env, assert_status=200).get_response_obj().id
return runner.call(env, assert_status=200, headers={"Content-Type": "application/json"}).get_response_obj().id
def complete_session(env, entity_type: EntityType, record_id: str, session_id: str, commit: bool):
state = "commit" if commit else "abandon"
url = build_base_url(entity_type) + f'/{record_id}/sessions/{session_id}'
runner = build_request(f'{state} session', 'PATCH', url, payload={'state': state})
runner.call(env).assert_ok()
runner.call(env, headers={"Content-Type": "application/json"}).assert_ok()
class ParquetSerializer:
......@@ -293,7 +293,7 @@ def test_get_data_with_column_filter(with_wdms_env):
with create_record(with_wdms_env, entity_type) as record_id:
size = 100
data = generate_df(['MD', 'X', 'Y', 'Z'], range(size))
data = generate_df(['MD', 'X', 'Y', 'Z', '2D[0]', '2D[1]', '2D[2]'], range(size))
data_to_send = serializer.dump(data)
headers = {'Content-Type': serializer.mime_type, 'Accept': serializer.mime_type}
......@@ -303,7 +303,10 @@ def test_get_data_with_column_filter(with_wdms_env):
validation_list = [ # tuple (params, expected_status, expected data)
({"curves": "MD"}, 200, data[['MD']]),
({"curves": "X, Y, Z"}, 200, data[['X', 'Y', 'Z']]),
({"curves": "W, X"}, 200, data[['X']])
({"curves": "W, X"}, 200, data[['X']]),
({"curves": "2D[0]"}, 200, data[['2D[0]']]),
({"curves": "2D[0:1]"}, 200, data[['2D[0]', '2D[1]']]),
({"curves": "2D"}, 200, data[['2D[0]', '2D[1]', '2D[2]']]),
]
for (params, expected_status, expected_data) in validation_list:
......
......@@ -10,16 +10,34 @@ from pandas.testing import assert_frame_equal
@pytest.mark.parametrize("requested, df_columns, expected", [
(["X"], {"X"}, ["X"]),
([], {"X"}, []),
(["X", "Y", "Z"], {"X", "Y"}, ["X", "Y"]),
(["X", "Y", "Z"], {"X", "Y", "Z"}, ["X", "Y", "Z"]),
(["2D"], {"X", "2D[0]", "2D[1]"}, ["2D[0]", "2D[1]"]),
(["2D[0]"], {"X", "2D[0]", "2D[1]"}, ["2D[0]"]),
(["X", "2D"], {"X", "2D[0]", "2D[1]"}, ["X", "2D[0]", "2D[1]"]),
(["X"], {"X"}, ["X"]),
([], {"X"}, []),
(["X", "Y", "Z"], {"X", "Y"}, ["X", "Y"]),
(["X", "Y", "Z"], {"X", "Y", "Z"}, ["X", "Y", "Z"]),
(["2D"], {"X", "2D[0]", "2D[1]"}, ["2D[0]", "2D[1]"]),
(["2D[0:1]"], {"X", "2D[0]", "2D[1]"}, ["2D[0]", "2D[1]"]),
(["2D[1:3]"], {"X", "2D[0]", "2D[1]"}, ["2D[1]"]),
(["2D[1:3]"], {"2D[0]", "2D[1]", "2D[3]"}, ["2D[1]", "2D[3]"]),
(["2D[0]"], {"X", "2D[0]", "2D[1]"}, ["2D[0]"]),
(["X", "2D"], {"X", "2D[0]", "2D[1]"}, ["X", "2D[0]", "2D[1]"]),
(["2D"], {"2D[str]", "2D[0]"}, ["2D[str]", "2D[0]"]),
(["2D[str:0]"], {"2D[str]", "2D[0]"}, []),
([""], {}, []), # empty string
([""], {""}, [""]), # empty string
(["a"], {"A"}, []), # case sensitive
(["X", "X", "X"], {"X", "Y"}, ["X"]), # removes duplication
(["NMR[0:2]", "GR[2:4]"], {"NMR[0]", "NMR[1]", "GR[2]"}, ["NMR[0]", "NMR[1]", "GR[2]"]), # multiple ranges
(["X[0]", "X[0:5]", "X[0:1]"], {"X[0]", "X[1]", "X[2]"}, ["X[0]", "X[1]", "X[2]"]), # removes duplication in overlapping ranges
(["X[0]"], {"X[0]", "X[0][1]"}, ["X[0]", "X[0][1]"]), # ensure that we capture only the last [..]
(["X[0]"], {"X[0][0]", "X[0][1]"}, ["X[0][0]", "X[0][1]"]),
(["X[0:1]"], {"X[0][0]", "X[0][1]"}, ["X[0][0]", "X[0][1]"]),
(["X[1:1]"], {"X[0][0]", "X[0][1]"}, ["X[0][1]"]),
(["Y[4:]"], {"Y[4], Y[5], Y[6]"}, []), # incomplete range?
(["Y[:4]"], {"Y[4], Y[5], Y[6]"}, []), # incomplete range?
(["2D[5]"], {"X", "2D[0]", "2D[1]"}, []),
])
def test_get_matching_column(requested, df_columns, expected):
result = DataFrameRender.get_matching_column(requested, df_columns)
result = DataFrameRender.get_matching_column(requested, set(df_columns))
assert set(result) == set(expected)
......
......@@ -12,7 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from app.bulk_persistence.dataframe_serializer import DataframeSerializer, JSONOrient
from app.bulk_persistence.dataframe_serializer import (DataframeSerializerSync,
DataframeSerializerAsync,
JSONOrient)
from tests.unit.test_utils import temp_directory
import pandas as pd
import json
......@@ -58,7 +60,7 @@ def check_dataframe(df: pd.DataFrame):
@pytest.mark.parametrize("orient", [o for o in JSONOrient])
def test_schema(orient):
assert DataframeSerializer.get_schema(orient)
assert DataframeSerializerSync.get_schema(orient)
@pytest.mark.parametrize("data_dict,orient", [(d, o) for o, d in dataframe_dict.items()])
......@@ -66,7 +68,7 @@ def test_load_from_str_various_orient(data_dict, orient):
print(orient)
dataframe_json = json.dumps(data_dict)
print(dataframe_json)
df = DataframeSerializer.read_json(dataframe_json, orient=orient)
df = DataframeSerializerSync.read_json(dataframe_json, orient=orient)
check_dataframe(df)