From b890fe852452169ec3e5b808d45a6b832d74b879 Mon Sep 17 00:00:00 2001
From: Manikanta Swamy Akula <V.ManikantaSwamyAku2@shell.com>
Date: Wed, 29 Nov 2023 20:31:58 +0530
Subject: [PATCH] CRS Catalog query_with_cursor needed changes

---
 .../osdu/crs/api/CrsCatalogApiV3.java         | 10 ++-
 .../crs/model/response/SearchResponse.java    | 23 +++++--
 .../crs/service/SearchWrapperService.java     | 52 +++++++++++++--
 .../catalog_test_core/test_crs_catalog_v3.py  | 64 +++++++++++++------
 4 files changed, 118 insertions(+), 31 deletions(-)

diff --git a/crs-catalog-core/src/main/java/org/opengroup/osdu/crs/api/CrsCatalogApiV3.java b/crs-catalog-core/src/main/java/org/opengroup/osdu/crs/api/CrsCatalogApiV3.java
index d72be229..946c0039 100644
--- a/crs-catalog-core/src/main/java/org/opengroup/osdu/crs/api/CrsCatalogApiV3.java
+++ b/crs-catalog-core/src/main/java/org/opengroup/osdu/crs/api/CrsCatalogApiV3.java
@@ -100,7 +100,11 @@ public class CrsCatalogApiV3 {
 	public SearchResponse getCoordinateTransformations(
 			@RequestBody(required=false) CoordinateTransformationsQuery coordinateTransformationsQuery
 			) {
-		return searchWrapperService.search(coordinateTransformationsQuery, SearchWrapperService.getCoordinateTransformationKind());
+		if(coordinateTransformationsQuery.getOffset()!=null){
+			return searchWrapperService.search(coordinateTransformationsQuery, SearchWrapperService.getCoordinateTransformationKind());
+
+		}
+			return searchWrapperService.searchWithCursor(coordinateTransformationsQuery, SearchWrapperService.getCoordinateTransformationKind());
 	}
 
 
@@ -146,7 +150,11 @@ public class CrsCatalogApiV3 {
 	public SearchResponse getCoordinateReferenceSystems(
 			@RequestBody(required=false) CoordinateReferenceSystemsQuery coordinateReferenceSystemsQuery
 	) {
+	if(coordinateReferenceSystemsQuery.getOffset()!=null){
 		return searchWrapperService.search(coordinateReferenceSystemsQuery, SearchWrapperService.getCoordinateReferenceSystemKind());
+
+	}
+		return searchWrapperService.searchWithCursor(coordinateReferenceSystemsQuery, SearchWrapperService.getCoordinateReferenceSystemKind());
 	}
 
 	@Operation(summary = "${CrsCatalogApiV3.CoordinateReferenceSystems.summary}", description = "${CrsCatalogApiV3.CoordinateReferenceSystems.description}",
diff --git a/crs-catalog-core/src/main/java/org/opengroup/osdu/crs/model/response/SearchResponse.java b/crs-catalog-core/src/main/java/org/opengroup/osdu/crs/model/response/SearchResponse.java
index 97daec86..55275e4b 100644
--- a/crs-catalog-core/src/main/java/org/opengroup/osdu/crs/model/response/SearchResponse.java
+++ b/crs-catalog-core/src/main/java/org/opengroup/osdu/crs/model/response/SearchResponse.java
@@ -17,14 +17,27 @@ package org.opengroup.osdu.crs.model.response;
 import io.swagger.v3.oas.annotations.media.Schema;
 import lombok.AllArgsConstructor;
 import lombok.Data;
+import org.opengroup.osdu.core.common.model.search.CursorQueryResponse;
 import org.opengroup.osdu.core.common.model.search.QueryResponse;
 
 @Data
-@AllArgsConstructor
 @Schema(description = "Results for most V3 endpoints")
 public class SearchResponse {
-	@Schema(description = "Results from Search service")
-	private QueryResponse searchResults;
-	@Schema(description = "Query string used against Search service")
-	private String query;
+    @Schema(description = "Results from Search service")
+    private QueryResponse searchResults;
+    private CursorQueryResponse cursorSearchResults;
+    @Schema(description = "Query string used against Search service")
+    private String query;
+
+
+    public SearchResponse(QueryResponse searchResults, String query) {
+        this.searchResults = searchResults;
+        this.query = query;
+
+    }
+
+    public SearchResponse(CursorQueryResponse cursorSearchResults, String qurey) {
+        this.cursorSearchResults = cursorSearchResults;
+        this.query = query;
+    }
 }
diff --git a/crs-catalog-core/src/main/java/org/opengroup/osdu/crs/service/SearchWrapperService.java b/crs-catalog-core/src/main/java/org/opengroup/osdu/crs/service/SearchWrapperService.java
index b2b39d72..b270e23a 100644
--- a/crs-catalog-core/src/main/java/org/opengroup/osdu/crs/service/SearchWrapperService.java
+++ b/crs-catalog-core/src/main/java/org/opengroup/osdu/crs/service/SearchWrapperService.java
@@ -16,10 +16,7 @@ package org.opengroup.osdu.crs.service;
 
 import org.opengroup.osdu.core.common.logging.JaxRsDpsLog;
 import org.opengroup.osdu.core.common.model.http.DpsHeaders;
-import org.opengroup.osdu.core.common.model.search.QueryRequest;
-import org.opengroup.osdu.core.common.model.search.QueryResponse;
-import org.opengroup.osdu.core.common.model.search.SearchException;
-import org.opengroup.osdu.core.common.model.search.SpatialFilter;
+import org.opengroup.osdu.core.common.model.search.*;
 import org.opengroup.osdu.core.common.search.ISearchFactory;
 import org.opengroup.osdu.core.common.search.ISearchService;
 import org.opengroup.osdu.crs.model.request.ISearchQuery;
@@ -123,7 +120,29 @@ public class SearchWrapperService {
 
         return new SearchResponse(sendToSearch(queryRequest), query);
     }
+    public SearchResponse searchWithCursor(ISearchQuery searchQuery, String kind) {
+        CursorQueryRequest cursorqueryRequest = new CursorQueryRequest ();
 
+        String query = searchQuery.constructQuery();
+        cursorqueryRequest.setQuery(query);
+        cursorqueryRequest.setKind(kind);
+
+        List<String> returnedFields = searchQuery.getReturnedFields();
+        if (returnedFields.size() > 0) {
+            cursorqueryRequest.setReturnedFields(returnedFields);
+        }
+
+        SpatialFilter spatialFilter = searchQuery.constructSpatialFilter();
+        if (spatialFilter != null) {
+            cursorqueryRequest.setSpatialFilter(searchQuery.constructSpatialFilter());
+        }
+
+        if (searchQuery.getLimit() != null) {
+            cursorqueryRequest.setLimit(searchQuery.getLimit());
+        }
+
+        return new SearchResponse(sendToSearchWithCursor(cursorqueryRequest), query);
+    }
     private QueryResponse sendToSearch(QueryRequest queryRequest) {
         QueryResponse queryResponse = null;
         try {
@@ -149,6 +168,31 @@ public class SearchWrapperService {
         }
         return queryResponse;
     }
+    private CursorQueryResponse sendToSearchWithCursor(CursorQueryRequest queryRequest) {
+        CursorQueryResponse cursorqueryResponse = null;
+        try {
+            logger.debug(String.format("Sending query to search service: %s", queryRequest.toString()));
+           
+            cursorqueryResponse = searchService.searchCursor(queryRequest);
+            List<Map<String, Object>> searchResultList = new ArrayList<Map<String, Object>>();
+            searchResultList = cursorqueryResponse.getResults();
+            String cursor_value = cursorqueryResponse.getCursor();
+            int default_Count = searchResultList.size();
+            long totalCount = cursorqueryResponse.getTotalCount();
+            while (default_Count < totalCount && queryRequest.getLimit() > DEFAULT_QUERY_LIMIT) {
+                queryRequest.setCursor(cursor_value);
+                cursorqueryResponse = searchService.searchCursor(queryRequest);
+                searchResultList.addAll(cursorqueryResponse.getResults());
+                default_Count += default_Count;
+
+            }
+            cursorqueryResponse.setResults(searchResultList);
+            logger.debug(String.format("Received response from search service: %s", cursorqueryResponse.toString()));
+        } catch (SearchException e) {
+            handleSearchError("Failed to call search service", e);
+        }
+        return cursorqueryResponse;
+    }
 
     private void handleSearchError(String errorMsg, Exception e) {
         logger.error(errorMsg, e);
diff --git a/testing/catalog_test_core/test_crs_catalog_v3.py b/testing/catalog_test_core/test_crs_catalog_v3.py
index 8b90fbcf..2bf8efaf 100644
--- a/testing/catalog_test_core/test_crs_catalog_v3.py
+++ b/testing/catalog_test_core/test_crs_catalog_v3.py
@@ -128,7 +128,7 @@ class TestCrsCatalog(unittest.TestCase):
             time.sleep(10)
 
     @staticmethod
-    def check_search_response_count(response, expected_count, test_name):
+    def check_get_search_response_count(response, expected_count, test_name):
         response_body = json.loads(response.content)
         assert response.status_code == 200
         response_count = len(response_body["searchResults"]["results"])
@@ -141,12 +141,28 @@ class TestCrsCatalog(unittest.TestCase):
             print(f'Error: Test {test_name} Expects {expected_count} records. Got {response_count} records.')
         assert response_count == expected_count
 
+    @staticmethod
+    def check_search_response_count(response, expected_count, test_name):
+        response_body = json.loads(response.content)
+        assert response.status_code == 200
+
+        response_count = len(response_body["cursorSearchResults"]["results"])
+
+        for test_response in response_body["cursorSearchResults"]["results"]:
+            if test_response["id"] not in record_id_set:
+                response_count -= 1
+
+        if response_count != expected_count:
+            print(f'Error: Test {test_name} Expects {expected_count} records. Got {response_count} records.')
+        assert response_count == expected_count
+
     def test_get_coordinate_transformation_dataId(self):
         with open(f'{self.path}v3/GetCoordinateTransformationTestData.json') as test_data_file:
             test_data = json.loads(test_data_file.read().replace('{{data_partition_id}}', constants.MY_TENANT))
             response = self.client.make_request('GET', f'{ct_endpoint_path}?dataId={test_data["dataId"]}')
             response_body = json.loads(response.content)
-            self.check_search_response_count(response, 1, "test_get_coordinate_transformation_dataId")
+        
+            self.check_get_search_response_count(response, 1, "test_get_coordinate_transformation_dataId")
             assert response_body["searchResults"]["results"][0]["id"] == test_data["recordId"]
             assert response_body["searchResults"]["results"][0]["data"]["ID"] == test_data["dataId"]
             for record_property in ('kind', 'version', 'acl', 'legal', 'namespace'):
@@ -161,7 +177,7 @@ class TestCrsCatalog(unittest.TestCase):
             test_data = json.loads(test_data_file.read().replace('{{data_partition_id}}', constants.MY_TENANT))
             response = self.client.make_request('GET', f'{ct_endpoint_path}?recordId={test_data["recordId"]}')
             response_body = json.loads(response.content)
-            self.check_search_response_count(response, 1, "test_get_coordinate_transformation_recordId")
+            self.check_get_search_response_count(response, 1, "test_get_coordinate_transformation_recordId")
             assert response_body["searchResults"]["results"][0]["id"] == test_data["recordId"]
             assert response_body["searchResults"]["results"][0]["data"]["ID"] == test_data["dataId"]
 
@@ -171,7 +187,7 @@ class TestCrsCatalog(unittest.TestCase):
             response = self.client.make_request('GET',
                                                 f'{ct_endpoint_path}?dataId={test_data["dataId"]}&recordId={test_data["recordId"]}')
             response_body = json.loads(response.content)
-            self.check_search_response_count(response, 1, "test_get_coordinate_transformation_dataId_recordId")
+            self.check_get_search_response_count(response, 1, "test_get_coordinate_transformation_dataId_recordId")
             assert response_body["searchResults"]["results"][0]["id"] == test_data["recordId"]
             assert response_body["searchResults"]["results"][0]["data"]["ID"] == test_data["dataId"]
 
@@ -180,7 +196,8 @@ class TestCrsCatalog(unittest.TestCase):
             test_data = json.loads(test_data_file.read().replace('{{data_partition_id}}', constants.MY_TENANT))
             response = self.client.make_request('GET', f'{crs_endpoint_path}?dataId={test_data["dataId"]}')
             response_body = json.loads(response.content)
-            self.check_search_response_count(response, 1, "test_get_coordinate_reference_system_dataId")
+    
+            self.check_get_search_response_count(response, 1, "test_get_coordinate_reference_system_dataId")
             assert response_body["searchResults"]["results"][0]["id"] == test_data["recordId"]
             assert response_body["searchResults"]["results"][0]["data"]["ID"] == test_data["dataId"]
             for record_property in ('kind', 'version', 'acl', 'legal', 'namespace'):
@@ -195,7 +212,7 @@ class TestCrsCatalog(unittest.TestCase):
             test_data = json.loads(test_data_file.read().replace('{{data_partition_id}}', constants.MY_TENANT))
             response = self.client.make_request('GET', f'{crs_endpoint_path}?recordId={test_data["recordId"]}')
             response_body = json.loads(response.content)
-            self.check_search_response_count(response, 1, "test_get_coordinate_reference_system_recordId")
+            self.check_get_search_response_count(response, 1, "test_get_coordinate_reference_system_recordId")
             assert response_body["searchResults"]["results"][0]["id"] == test_data["recordId"]
             assert response_body["searchResults"]["results"][0]["data"]["ID"] == test_data["dataId"]
 
@@ -205,7 +222,7 @@ class TestCrsCatalog(unittest.TestCase):
             response = self.client.make_request('GET',
                                                 f'{crs_endpoint_path}?dataId={test_data["dataId"]}&recordId={test_data["recordId"]}')
             response_body = json.loads(response.content)
-            self.check_search_response_count(response, 1, "test_get_coordinate_reference_system_dataId_recordId")
+            self.check_get_search_response_count(response, 1, "test_get_coordinate_reference_system_dataId_recordId")
             assert response_body["searchResults"]["results"][0]["id"] == test_data["recordId"]
             assert response_body["searchResults"]["results"][0]["data"]["ID"] == test_data["dataId"]
 
@@ -228,7 +245,8 @@ class TestCrsCatalog(unittest.TestCase):
     def test_search_coordinate_transformations_with_common_partial_name(self):
         with open(f'{self.path}v3/SearchCoordinateTransformationsWithCommonPartialName.json') as test_data:
             response = self.client.make_request('POST', ct_endpoint_path, test_data)
-            self.check_search_response_count(response, 2, "test_search_coordinate_transformations_with_common_partial_name")
+            self.check_search_response_count(response, 2,
+                                             "test_search_coordinate_transformations_with_common_partial_name")
 
     def test_search_coordinate_transformations_with_wrong_name(self):
         with open(f'{self.path}v3/SearchCoordinateTransformationsWithWrongName.json') as test_data:
@@ -239,14 +257,15 @@ class TestCrsCatalog(unittest.TestCase):
         with open(f'{self.path}v3/SearchCoordinateTransformationsReverseSourceAndTarget.json') as test_data_file:
             test_data = test_data_file.read().replace('{{data_partition_id}}', constants.MY_TENANT)
             response = self.client.make_request('POST', ct_endpoint_path, test_data)
-            self.check_search_response_count(response, 1, "test_search_coordinate_transformations_with_reversed_source_and_target_crs")
+            self.check_search_response_count(response, 1,
+                                             "test_search_coordinate_transformations_with_reversed_source_and_target_crs")
 
     def test_search_coordinate_transformations_find_horizontal(self):
         test_data = "{}"
         response = self.client.make_request('POST', ct_endpoint_path, test_data)
         response_body = json.loads(response.content)
         self.check_search_response_count(response, 1, "test_search_coordinate_transformations_find_horizontal")
-        assert response_body["searchResults"]["results"][0]["data"]["Code"] == "1111"
+        assert response_body["cursorSearchResults"]["results"][0]["data"]["Code"] == "1111"
 
     def test_search_coordinate_transformations_find_vertical(self):
         with open(f'{self.path}v3/SearchCoordinateTransformationsVertical.json') as test_data_file:
@@ -254,21 +273,22 @@ class TestCrsCatalog(unittest.TestCase):
             response = self.client.make_request('POST', ct_endpoint_path, test_data)
             response_body = json.loads(response.content)
             self.check_search_response_count(response, 1, "test_search_coordinate_transformations_find_vertical")
-            assert response_body["searchResults"]["results"][0]["data"]["Code"] == "5429"
+            assert response_body["cursorSearchResults"]["results"][0]["data"]["Code"] == "5429"
             for record_property in ('Code', 'Name', 'Kind', 'InactiveIndicator', 'CodeSpace', 'PreferredUsage.Name',
                                     'PreferredUsage.Extent.Description', 'PreferredUsage.Extent.Name',
                                     'CoordinateTransformationType', 'SourceCRS.Name', 'SourceCRS.AuthorityCode.Code',
                                     'TargetCRS.Name', 'TargetCRS.AuthorityCode.Code'):
-                assert record_property in response_body["searchResults"]["results"][0]['data']
-            assert "InformationSource" not in response_body["searchResults"]["results"][0]['data']
+                assert record_property in response_body["cursorSearchResults"]["results"][0]['data']
+            assert "InformationSource" not in response_body["cursorSearchResults"]["results"][0]['data']
 
     def test_search_coordinate_transformations_find_vertical_return_all_fields(self):
         with open(f'{self.path}v3/SearchCoordinateTransformationsVerticalReturnAllFields.json') as test_data_file:
             test_data = test_data_file.read()
             response = self.client.make_request('POST', ct_endpoint_path, test_data)
             response_body = json.loads(response.content)
-            self.check_search_response_count(response, 1, "test_search_coordinate_transformations_find_vertical_return_all_fields")
-            assert "InformationSource" in response_body["searchResults"]["results"][0]['data']
+            self.check_search_response_count(response, 1,
+                                             "test_search_coordinate_transformations_find_vertical_return_all_fields")
+            assert "InformationSource" in response_body["cursorSearchResults"]["results"][0]['data']
 
     def test_search_coordinate_transformations_find_all(self):
         with open(f'{self.path}v3/SearchCoordinateTransformationsAllKinds.json') as test_data_file:
@@ -280,7 +300,8 @@ class TestCrsCatalog(unittest.TestCase):
         with open(f'{self.path}v3/SearchCoordinateTransformationsIncludeDeprecated.json') as test_data_file:
             test_data = test_data_file.read()
             response = self.client.make_request('POST', ct_endpoint_path, test_data)
-            self.check_search_response_count(response, 3, "test_search_coordinate_transformations_find_all_include_deprecated")
+            self.check_search_response_count(response, 3,
+                                             "test_search_coordinate_transformations_find_all_include_deprecated")
 
     def test_search_coordinate_reference_systems(self):
         with open(f'{self.path}v3/SearchCoordinateReferenceSystems.json') as test_data_file:
@@ -299,19 +320,20 @@ class TestCrsCatalog(unittest.TestCase):
             response = self.client.make_request('POST', crs_endpoint_path, test_data)
             self.check_search_response_count(response, 1, "test_search_coordinate_reference_systems_bound_projected")
             response_body = json.loads(response.content)
-            assert response_body["searchResults"]["results"][0]["data"]["Code"] == "24600001"
+            assert response_body["cursorSearchResults"]["results"][0]["data"]["Code"] == "24600001"
             for record_property in ('Code', 'Name', 'Kind', 'InactiveIndicator', 'CodeSpace', 'PreferredUsage.Name',
                                     'PreferredUsage.Extent.Description', 'PreferredUsage.Scope.Name',
                                     'CoordinateSystem.Name', 'Datum.Name', 'RevisionDate'):
-                assert record_property in response_body["searchResults"]["results"][0]['data']
-            assert "InformationSource" not in response_body["searchResults"]["results"][0]['data']
+                assert record_property in response_body["cursorSearchResults"]["results"][0]['data']
+            assert "InformationSource" not in response_body["cursorSearchResults"]["results"][0]['data']
 
     def test_search_coordinate_reference_systems_bound_projected_return_all_fields(self):
         with open(f'{self.path}v3/SearchCoordinateReferenceSystemsBoundProjectedReturnAllFields.json') as test_data:
             response = self.client.make_request('POST', crs_endpoint_path, test_data)
-            self.check_search_response_count(response, 1, "test_search_coordinate_reference_systems_bound_projected_return_all_fields")
+            self.check_search_response_count(response, 1,
+                                             "test_search_coordinate_reference_systems_bound_projected_return_all_fields")
             response_body = json.loads(response.content)
-            assert "InformationSource" in response_body["searchResults"]["results"][0]['data']
+            assert "InformationSource" in response_body["cursorSearchResults"]["results"][0]['data']
 
     def test_search_coordinate_reference_systems_bound_geographic2d(self):
         with open(f'{self.path}v3/SearchCoordinateReferenceSystemsBoundGeographic2D.json') as test_data:
-- 
GitLab