|
17 | 17 | package org.apache.kafka.common.security.scram.internals; |
18 | 18 |
|
19 | 19 |
|
20 | | -import java.nio.charset.StandardCharsets; |
21 | | -import java.util.HashMap; |
22 | | - |
23 | 20 | import org.apache.kafka.common.errors.SaslAuthenticationException; |
24 | 21 | import org.apache.kafka.common.security.authenticator.CredentialCache; |
25 | 22 | import org.apache.kafka.common.security.scram.ScramCredential; |
| 23 | +import org.apache.kafka.common.security.scram.internals.ScramMessages.ClientFirstMessage; |
| 24 | +import org.apache.kafka.common.security.scram.internals.ScramMessages.ServerFinalMessage; |
| 25 | +import org.apache.kafka.common.security.scram.internals.ScramMessages.ServerFirstMessage; |
26 | 26 | import org.apache.kafka.common.security.token.delegation.internals.DelegationTokenCache; |
27 | 27 |
|
28 | 28 | import org.junit.jupiter.api.BeforeEach; |
29 | 29 | import org.junit.jupiter.api.Test; |
| 30 | +import org.mockito.Mockito; |
| 31 | + |
| 32 | +import java.nio.charset.StandardCharsets; |
| 33 | +import java.util.Base64; |
| 34 | +import java.util.HashMap; |
30 | 35 |
|
| 36 | +import javax.security.sasl.SaslException; |
| 37 | + |
| 38 | +import static org.junit.jupiter.api.Assertions.assertEquals; |
| 39 | +import static org.junit.jupiter.api.Assertions.assertNull; |
31 | 40 | import static org.junit.jupiter.api.Assertions.assertThrows; |
32 | 41 | import static org.junit.jupiter.api.Assertions.assertTrue; |
33 | 42 |
|
@@ -67,10 +76,69 @@ public void authorizationIdNotEqualsAuthenticationId() { |
67 | 76 | assertThrows(SaslAuthenticationException.class, () -> saslServer.evaluateResponse(clientFirstMessage(USER_A, USER_B))); |
68 | 77 | } |
69 | 78 |
|
| 79 | + /** |
| 80 | + * Validate that server responds with client's nonce as prefix of its nonce in the |
| 81 | + * server first message. |
| 82 | + * <br> |
| 83 | + * In addition, it checks that the client final message has nonce that it sent in its |
| 84 | + * first message. |
| 85 | + */ |
| 86 | + @Test |
| 87 | + public void validateNonceExchange() throws SaslException { |
| 88 | + ScramSaslServer spySaslServer = Mockito.spy(saslServer); |
| 89 | + byte[] clientFirstMsgBytes = clientFirstMessage(USER_A, USER_A); |
| 90 | + ClientFirstMessage clientFirstMessage = new ClientFirstMessage(clientFirstMsgBytes); |
| 91 | + |
| 92 | + byte[] serverFirstMsgBytes = spySaslServer.evaluateResponse(clientFirstMsgBytes); |
| 93 | + ServerFirstMessage serverFirstMessage = new ServerFirstMessage(serverFirstMsgBytes); |
| 94 | + assertTrue(serverFirstMessage.nonce().startsWith(clientFirstMessage.nonce()), |
| 95 | + "Nonce in server message should start with client first message's nonce"); |
| 96 | + |
| 97 | + byte[] clientFinalMessage = clientFinalMessage(serverFirstMessage.nonce()); |
| 98 | + Mockito.doNothing() |
| 99 | + .when(spySaslServer).verifyClientProof(Mockito.any(ScramMessages.ClientFinalMessage.class)); |
| 100 | + byte[] serverFinalMsgBytes = spySaslServer.evaluateResponse(clientFinalMessage); |
| 101 | + ServerFinalMessage serverFinalMessage = new ServerFinalMessage(serverFinalMsgBytes); |
| 102 | + assertNull(serverFinalMessage.error(), "Server final message should not contain error"); |
| 103 | + } |
| 104 | + |
| 105 | + @Test |
| 106 | + public void validateFailedNonceExchange() throws SaslException { |
| 107 | + ScramSaslServer spySaslServer = Mockito.spy(saslServer); |
| 108 | + byte[] clientFirstMsgBytes = clientFirstMessage(USER_A, USER_A); |
| 109 | + ClientFirstMessage clientFirstMessage = new ClientFirstMessage(clientFirstMsgBytes); |
| 110 | + |
| 111 | + byte[] serverFirstMsgBytes = spySaslServer.evaluateResponse(clientFirstMsgBytes); |
| 112 | + ServerFirstMessage serverFirstMessage = new ServerFirstMessage(serverFirstMsgBytes); |
| 113 | + assertTrue(serverFirstMessage.nonce().startsWith(clientFirstMessage.nonce()), |
| 114 | + "Nonce in server message should start with client first message's nonce"); |
| 115 | + |
| 116 | + byte[] clientFinalMessage = clientFinalMessage(formatter.secureRandomString()); |
| 117 | + Mockito.doNothing() |
| 118 | + .when(spySaslServer).verifyClientProof(Mockito.any(ScramMessages.ClientFinalMessage.class)); |
| 119 | + SaslException saslException = assertThrows(SaslException.class, |
| 120 | + () -> spySaslServer.evaluateResponse(clientFinalMessage)); |
| 121 | + assertEquals("Invalid client nonce in the final client message.", |
| 122 | + saslException.getMessage(), |
| 123 | + "Failure message: " + saslException.getMessage()); |
| 124 | + } |
| 125 | + |
70 | 126 | private byte[] clientFirstMessage(String userName, String authorizationId) { |
71 | 127 | String nonce = formatter.secureRandomString(); |
72 | 128 | String authorizationField = authorizationId != null ? "a=" + authorizationId : ""; |
73 | 129 | String firstMessage = String.format("n,%s,n=%s,r=%s", authorizationField, userName, nonce); |
74 | 130 | return firstMessage.getBytes(StandardCharsets.UTF_8); |
75 | 131 | } |
| 132 | + |
| 133 | + private byte[] clientFinalMessage(String nonce) { |
| 134 | + String channelBinding = randomBytesAsString(); |
| 135 | + String proof = randomBytesAsString(); |
| 136 | + |
| 137 | + String message = String.format("c=%s,r=%s,p=%s", channelBinding, nonce, proof); |
| 138 | + return message.getBytes(StandardCharsets.UTF_8); |
| 139 | + } |
| 140 | + |
| 141 | + private String randomBytesAsString() { |
| 142 | + return Base64.getEncoder().encodeToString(formatter.secureRandomBytes()); |
| 143 | + } |
76 | 144 | } |
0 commit comments