diff --git a/provider/partition-azure/src/main/java/org/opengroup/osdu/partition/provider/azure/utils/AuthorizationService.java b/provider/partition-azure/src/main/java/org/opengroup/osdu/partition/provider/azure/utils/AuthorizationService.java index 8bddf4abad03da90b805c997690ea9f6e8bfd39f..ed0f6df00daec6a4a34da62a3a95eb4ca09c2a88 100644 --- a/provider/partition-azure/src/main/java/org/opengroup/osdu/partition/provider/azure/utils/AuthorizationService.java +++ b/provider/partition-azure/src/main/java/org/opengroup/osdu/partition/provider/azure/utils/AuthorizationService.java @@ -25,6 +25,9 @@ import java.util.Map; @Component public class AuthorizationService implements IAuthorizationService { + private final String AAD_issuer_v1 = "https://sts.windows.net"; + private final String AAD_issuer_v2 = "https://login.microsoftonline.com"; + enum UserType { REGULAR_USER, GUEST_USER, @@ -40,14 +43,24 @@ public class AuthorizationService implements IAuthorizationService { } final UserPrincipal userPrincipal = (UserPrincipal) principal; + String issuer = userPrincipal.getClaim("iss").toString(); UserType type = getType(userPrincipal); - if (type == UserType.SERVICE_PRINCIPAL) { + if (type == UserType.SERVICE_PRINCIPAL && issuedByAAD(issuer)) { return true; } return false; } + /*** + * Check that issuer string startswith accepted prefix of AAD issuer url (V1 or V2). + * @param issuer claim for "issuer" + * @return true if issuer startswith V1 url or V2 url + */ + private boolean issuedByAAD(String issuer) { + return issuer.startsWith(AAD_issuer_v1) || issuer.startsWith(AAD_issuer_v2); + } + /** * The internal method to get the user principal. * diff --git a/provider/partition-azure/src/test/java/org/opengroup/osdu/partition/provider/azure/utils/AuthorizationServiceTest.java b/provider/partition-azure/src/test/java/org/opengroup/osdu/partition/provider/azure/utils/AuthorizationServiceTest.java index f76d96dcc47e2e8f0f8e37fe858ba768a79e88bc..0d732a671c334790714808b69fc07b0ca95e23ba 100644 --- a/provider/partition-azure/src/test/java/org/opengroup/osdu/partition/provider/azure/utils/AuthorizationServiceTest.java +++ b/provider/partition-azure/src/test/java/org/opengroup/osdu/partition/provider/azure/utils/AuthorizationServiceTest.java @@ -103,11 +103,23 @@ public class AuthorizationServiceTest { } @Test - public void shouldReturnTrueWhenAADTokenIsSetInContext() { + public void shouldReturnTrueWhenAADTokenIsSetInContext_AndIssuerIsAAD() { createAADUserPrincipalSetSecurityContext(TestUtils.APPID, TestUtils.getAppId(), TestUtils.getAadIssuer()); assertTrue(authorizationService.isDomainAdminServiceAccount()); } + @Test + public void shouldReturnTrueWhenAADTokenIsSetInContext_AndIssuerIsAADV2() { + createAADUserPrincipalSetSecurityContext(TestUtils.APPID, TestUtils.getAppId(), TestUtils.getAadIssuerV2()); + assertTrue(authorizationService.isDomainAdminServiceAccount()); + } + + @Test + public void shouldReturnFalseWhenAADTokenIsSetInContext_AndIssuerIsNotAAD() { + createAADUserPrincipalSetSecurityContext(TestUtils.APPID, TestUtils.getAppId(), TestUtils.getNonAadIssuer()); + assertFalse(authorizationService.isDomainAdminServiceAccount()); + } + @Getter public class DummyAuthToken { diff --git a/provider/partition-azure/src/test/java/org/opengroup/osdu/partition/provider/azure/utils/TestUtils.java b/provider/partition-azure/src/test/java/org/opengroup/osdu/partition/provider/azure/utils/TestUtils.java index 392665322ef141e5fd31a07da30529980b77ae0a..cfa415a48ae7de425f8b9dd58705baea92260628 100644 --- a/provider/partition-azure/src/test/java/org/opengroup/osdu/partition/provider/azure/utils/TestUtils.java +++ b/provider/partition-azure/src/test/java/org/opengroup/osdu/partition/provider/azure/utils/TestUtils.java @@ -18,7 +18,11 @@ public class TestUtils { private static final String appId = "1234"; public static final String APPID = "appid"; public static final String aadIssuer = "https://sts.windows.net"; + public static final String aadIssuerV2 = "https://login.microsoftonline.com"; + public static final String nonAadIssuer = "https://login.abc.com"; public static String getAppId() {return appId;} public static String getAadIssuer() {return aadIssuer;} + public static String getAadIssuerV2() {return aadIssuerV2;} + public static String getNonAadIssuer() {return nonAadIssuer;} }