From e91c3337ff440b22c430f02cdb54e313f94f3ef2 Mon Sep 17 00:00:00 2001
From: Larissa Pereira <LPereira14@slb.com>
Date: Mon, 5 Apr 2021 15:37:56 -0500
Subject: [PATCH] add correlation id to response header

---
 .../osdu/indexer/util/IndexFilter.java        | 49 +++++++++++++++++++
 .../osdu/indexer/util/IndexFilterTest.java    | 43 ++++++++++++++++
 .../osdu/indexer/middleware/IndexFilter.java  |  4 +-
 .../osdu/indexer/middleware/IndexFilter.java  |  3 +-
 4 files changed, 94 insertions(+), 5 deletions(-)
 create mode 100644 indexer-core/src/main/java/org/opengroup/osdu/indexer/util/IndexFilter.java
 create mode 100644 indexer-core/src/test/java/org/opengroup/osdu/indexer/util/IndexFilterTest.java

diff --git a/indexer-core/src/main/java/org/opengroup/osdu/indexer/util/IndexFilter.java b/indexer-core/src/main/java/org/opengroup/osdu/indexer/util/IndexFilter.java
new file mode 100644
index 000000000..0acdc656c
--- /dev/null
+++ b/indexer-core/src/main/java/org/opengroup/osdu/indexer/util/IndexFilter.java
@@ -0,0 +1,49 @@
+
+package org.opengroup.osdu.indexer.util;
+
+import java.io.IOException;
+import javax.servlet.Filter;
+import javax.servlet.FilterChain;
+import javax.servlet.FilterConfig;
+import javax.servlet.ServletException;
+import javax.servlet.ServletRequest;
+import javax.servlet.ServletResponse;
+import javax.servlet.http.HttpServletResponse;
+import lombok.extern.java.Log;
+import org.opengroup.osdu.core.common.model.http.DpsHeaders;
+import org.springframework.beans.factory.annotation.Autowired;
+import org.springframework.stereotype.Component;
+
+@Log
+@Component
+public class IndexFilter implements Filter {
+
+    private final DpsHeaders dpsHeaders;
+
+    @Autowired
+    public IndexFilter(DpsHeaders dpsHeaders) {
+        this.dpsHeaders = dpsHeaders;
+    }
+
+    @Override
+    public void init(FilterConfig filterConfig) throws ServletException {
+    }
+
+    @Override
+    public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse,
+                         FilterChain filterChain)
+            throws IOException, ServletException {
+
+        HttpServletResponse httpResponse = (HttpServletResponse) servletResponse;
+
+        dpsHeaders.addCorrelationIdIfMissing();
+        httpResponse.addHeader(DpsHeaders.CORRELATION_ID, dpsHeaders.getCorrelationId());
+
+        filterChain.doFilter(servletRequest, servletResponse);
+    }
+
+    @Override
+    public void destroy() {
+    }
+
+}
diff --git a/indexer-core/src/test/java/org/opengroup/osdu/indexer/util/IndexFilterTest.java b/indexer-core/src/test/java/org/opengroup/osdu/indexer/util/IndexFilterTest.java
new file mode 100644
index 000000000..d404aa9dd
--- /dev/null
+++ b/indexer-core/src/test/java/org/opengroup/osdu/indexer/util/IndexFilterTest.java
@@ -0,0 +1,43 @@
+package org.opengroup.osdu.indexer.util;
+
+
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.mockito.InjectMocks;
+import org.mockito.Mock;
+import org.mockito.Mockito;
+import org.mockito.junit.MockitoJUnitRunner;
+import org.opengroup.osdu.core.common.model.http.DpsHeaders;
+
+import javax.servlet.FilterChain;
+import javax.servlet.ServletException;
+import javax.servlet.http.HttpServletRequest;
+import javax.servlet.http.HttpServletResponse;
+
+import java.io.IOException;
+import java.util.Collections;
+
+@RunWith(MockitoJUnitRunner.class)
+public class IndexFilterTest {
+
+    @InjectMocks
+    private IndexFilter indexFilter;
+
+    @Mock
+    private DpsHeaders dpsHeaders;
+
+    @Test
+    public void shouldSetCorrectResponseHeaders() throws IOException, ServletException {
+        HttpServletRequest httpServletRequest = Mockito.mock(HttpServletRequest.class);
+        HttpServletResponse httpServletResponse = Mockito.mock(HttpServletResponse.class);
+        FilterChain filterChain = Mockito.mock(FilterChain.class);
+
+        Mockito.when(dpsHeaders.getCorrelationId()).thenReturn("correlation-id-value");
+
+        indexFilter.doFilter(httpServletRequest, httpServletResponse, filterChain);
+
+        Mockito.verify(httpServletResponse).addHeader("correlation-id", "correlation-id-value");
+        Mockito.verify(filterChain).doFilter(httpServletRequest, httpServletResponse);
+    }
+}
+
diff --git a/provider/indexer-gcp/src/main/java/org/opengroup/osdu/indexer/middleware/IndexFilter.java b/provider/indexer-gcp/src/main/java/org/opengroup/osdu/indexer/middleware/IndexFilter.java
index cd7d4622c..95b3449d8 100644
--- a/provider/indexer-gcp/src/main/java/org/opengroup/osdu/indexer/middleware/IndexFilter.java
+++ b/provider/indexer-gcp/src/main/java/org/opengroup/osdu/indexer/middleware/IndexFilter.java
@@ -74,8 +74,6 @@ public class IndexFilter implements Filter {
             checkWorkerApiAccess(requestInfo);
         }
 
-        filterChain.doFilter(servletRequest, servletResponse);
-
         HttpServletResponse httpResponse = (HttpServletResponse) servletResponse;
         Map<String, List<Object>> standardHeaders = ResponseHeaders.STANDARD_RESPONSE_HEADERS;
         for (Map.Entry<String, List<Object>> header : standardHeaders.entrySet()) {
@@ -84,7 +82,7 @@ public class IndexFilter implements Filter {
         if (httpResponse.getHeader(DpsHeaders.CORRELATION_ID) == null) {
             httpResponse.addHeader(DpsHeaders.CORRELATION_ID, dpsHeaders.getCorrelationId());
         }
-
+        filterChain.doFilter(servletRequest, servletResponse);
     }
 
     @Override
diff --git a/provider/indexer-reference/src/main/java/org/opengroup/osdu/indexer/middleware/IndexFilter.java b/provider/indexer-reference/src/main/java/org/opengroup/osdu/indexer/middleware/IndexFilter.java
index 12cffcf01..79bb4e921 100644
--- a/provider/indexer-reference/src/main/java/org/opengroup/osdu/indexer/middleware/IndexFilter.java
+++ b/provider/indexer-reference/src/main/java/org/opengroup/osdu/indexer/middleware/IndexFilter.java
@@ -89,8 +89,6 @@ public class IndexFilter implements Filter {
       checkWorkerApiAccess(requestInfo);
     }
 
-    filterChain.doFilter(servletRequest, servletResponse);
-
     HttpServletResponse httpResponse = (HttpServletResponse) servletResponse;
     Map<String, List<Object>> standardHeaders = ResponseHeaders.STANDARD_RESPONSE_HEADERS;
     for (Map.Entry<String, List<Object>> header : standardHeaders.entrySet()) {
@@ -99,6 +97,7 @@ public class IndexFilter implements Filter {
     if (httpResponse.getHeader(DpsHeaders.CORRELATION_ID) == null) {
       httpResponse.addHeader(DpsHeaders.CORRELATION_ID, dpsHeaders.getCorrelationId());
     }
+    filterChain.doFilter(servletRequest, servletResponse);
   }
 
   @Override
-- 
GitLab