Commit d3f573a5 authored by Jeremie Hallal's avatar Jeremie Hallal
Browse files

Implement columns selection patterns

parent 169a2b09
......@@ -15,6 +15,7 @@
import asyncio
from typing import List, Set, Optional
import re
from contextlib import suppress
from fastapi import APIRouter, Depends, HTTPException, Request, status
from fastapi.responses import Response
......@@ -90,26 +91,33 @@ class DataFrameRender:
driver = await with_dask_blob_storage()
return await driver.client.submit(lambda: len(df.index))
re_1D_curve_selection = re.compile(r'\[(?P<index>[0-9]+)\]$')
re_2D_curve_selection = re.compile(r'\[(?P<range>[0-9]+:[0-9]+)\]$')
re_array_selection = re.compile(r'^(?P<name>.+)\[(?P<start>[^:]+):?(?P<stop>.*)\]$')
def _col_matching(sel, col):
if sel == col: # exact match
return True
m_col = DataFrameRender.re_array_selection.match(col)
if not m_col: # if the column doesn't have an array pattern (col[*])
return False
# compare selection with curve name without array suffix [*]
if sel == m_col['name']: # if selection is 'c', c[*] should match
return True
# range selection use cases c[0:2] should match c[0], c[1] and c[2]
m_sel = DataFrameRender.re_array_selection.match(sel)
if m_sel and m_sel['stop']:
with suppress(ValueError): # suppress int conversion exceptions
if int(m_sel['start']) <= int(m_col['start']) <= int(m_sel['stop']):
return True
return False
def get_matching_column(selection: List[str], cols: Set[str]) -> List[str]:
selected = set()
for to_find in selection:
m =
if m:
r = range(*map(int, m['range'].split(':')))
def is_matching(c):
if c == to_find:
return True
i =
return i and int(i['index']) in r
def is_matching(c):
return c == to_find or to_find == DataFrameRender.re_1D_curve_selection.sub('', c)
selected.update(filter(is_matching, cols.difference(selected)))
for sel in selection:
selected.update(filter(lambda col: DataFrameRender._col_matching(sel, col),
return list(selected)
......@@ -116,14 +116,14 @@ def create_session(env, entity_type: EntityType, record_id: str, overwrite: bool
url = build_base_url(entity_type) + f'/{record_id}/sessions'
runner = build_request(f'create {entity_type} session', 'POST', url,
payload={'mode': 'overwrite' if overwrite else 'update'})
return, assert_status=200).get_response_obj().id
return, assert_status=200, headers={"Content-Type": "application/json"}).get_response_obj().id
def complete_session(env, entity_type: EntityType, record_id: str, session_id: str, commit: bool):
state = "commit" if commit else "abandon"
url = build_base_url(entity_type) + f'/{record_id}/sessions/{session_id}'
runner = build_request(f'{state} session', 'PATCH', url, payload={'state': state}), headers={"Content-Type": "application/json"}).assert_ok()
class ParquetSerializer:
......@@ -293,7 +293,7 @@ def test_get_data_with_column_filter(with_wdms_env):
with create_record(with_wdms_env, entity_type) as record_id:
size = 100
data = generate_df(['MD', 'X', 'Y', 'Z'], range(size))
data = generate_df(['MD', 'X', 'Y', 'Z', '2D[0]', '2D[1]', '2D[2]'], range(size))
data_to_send = serializer.dump(data)
headers = {'Content-Type': serializer.mime_type, 'Accept': serializer.mime_type}
......@@ -303,7 +303,10 @@ def test_get_data_with_column_filter(with_wdms_env):
validation_list = [ # tuple (params, expected_status, expected data)
({"curves": "MD"}, 200, data[['MD']]),
({"curves": "X, Y, Z"}, 200, data[['X', 'Y', 'Z']]),
({"curves": "W, X"}, 200, data[['X']])
({"curves": "W, X"}, 200, data[['X']]),
({"curves": "2D[0]"}, 200, data[['2D[0]']]),
({"curves": "2D[0:1]"}, 200, data[['2D[0]', '2D[1]']]),
({"curves": "2D"}, 200, data[['2D[0]', '2D[1]', '2D[2]']]),
for (params, expected_status, expected_data) in validation_list:
......@@ -10,16 +10,34 @@ from pandas.testing import assert_frame_equal
@pytest.mark.parametrize("requested, df_columns, expected", [
(["X"], {"X"}, ["X"]),
([], {"X"}, []),
(["X", "Y", "Z"], {"X", "Y"}, ["X", "Y"]),
(["X", "Y", "Z"], {"X", "Y", "Z"}, ["X", "Y", "Z"]),
(["2D"], {"X", "2D[0]", "2D[1]"}, ["2D[0]", "2D[1]"]),
(["2D[0]"], {"X", "2D[0]", "2D[1]"}, ["2D[0]"]),
(["X", "2D"], {"X", "2D[0]", "2D[1]"}, ["X", "2D[0]", "2D[1]"]),
(["X"], {"X"}, ["X"]),
([], {"X"}, []),
(["X", "Y", "Z"], {"X", "Y"}, ["X", "Y"]),
(["X", "Y", "Z"], {"X", "Y", "Z"}, ["X", "Y", "Z"]),
(["2D"], {"X", "2D[0]", "2D[1]"}, ["2D[0]", "2D[1]"]),
(["2D[0:1]"], {"X", "2D[0]", "2D[1]"}, ["2D[0]", "2D[1]"]),
(["2D[1:3]"], {"X", "2D[0]", "2D[1]"}, ["2D[1]"]),
(["2D[1:3]"], {"2D[0]", "2D[1]", "2D[3]"}, ["2D[1]", "2D[3]"]),
(["2D[0]"], {"X", "2D[0]", "2D[1]"}, ["2D[0]"]),
(["X", "2D"], {"X", "2D[0]", "2D[1]"}, ["X", "2D[0]", "2D[1]"]),
(["2D"], {"2D[str]", "2D[0]"}, ["2D[str]", "2D[0]"]),
(["2D[str:0]"], {"2D[str]", "2D[0]"}, []),
([""], {}, []), # empty string
([""], {""}, [""]), # empty string
(["a"], {"A"}, []), # case sensitive
(["X", "X", "X"], {"X", "Y"}, ["X"]), # removes duplication
(["NMR[0:2]", "GR[2:4]"], {"NMR[0]", "NMR[1]", "GR[2]"}, ["NMR[0]", "NMR[1]", "GR[2]"]), # multiple ranges
(["X[0]", "X[0:5]", "X[0:1]"], {"X[0]", "X[1]", "X[2]"}, ["X[0]", "X[1]", "X[2]"]), # removes duplication in overlapping ranges
(["X[0]"], {"X[0]", "X[0][1]"}, ["X[0]", "X[0][1]"]), # ensure that we capture only the last [..]
(["X[0]"], {"X[0][0]", "X[0][1]"}, ["X[0][0]", "X[0][1]"]),
(["X[0:1]"], {"X[0][0]", "X[0][1]"}, ["X[0][0]", "X[0][1]"]),
(["X[1:1]"], {"X[0][0]", "X[0][1]"}, ["X[0][1]"]),
(["Y[4:]"], {"Y[4], Y[5], Y[6]"}, []), # incomplete range?
(["Y[:4]"], {"Y[4], Y[5], Y[6]"}, []), # incomplete range?
(["2D[5]"], {"X", "2D[0]", "2D[1]"}, []),
def test_get_matching_column(requested, df_columns, expected):
result = DataFrameRender.get_matching_column(requested, df_columns)
result = DataFrameRender.get_matching_column(requested, set(df_columns))
assert set(result) == set(expected)
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment