Skip to content

Commit 0a00456

Browse files
soondenanaomkreddy
authored andcommitted
MINOR: Few cleanups
Reviewers: Manikumar Reddy <[email protected]>
1 parent 450b707 commit 0a00456

2 files changed

Lines changed: 76 additions & 4 deletions

File tree

clients/src/main/java/org/apache/kafka/common/security/scram/internals/ScramSaslServer.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,9 @@ public byte[] evaluateResponse(byte[] response) throws SaslException, SaslAuthen
149149
case RECEIVE_CLIENT_FINAL_MESSAGE:
150150
try {
151151
ClientFinalMessage clientFinalMessage = new ClientFinalMessage(response);
152+
if (!clientFinalMessage.nonce().endsWith(serverFirstMessage.nonce())) {
153+
throw new SaslException("Invalid client nonce in the final client message.");
154+
}
152155
verifyClientProof(clientFinalMessage);
153156
byte[] serverKey = scramCredential.serverKey();
154157
byte[] serverSignature = formatter.serverSignature(serverKey, clientFirstMessage, serverFirstMessage, clientFinalMessage);
@@ -222,7 +225,8 @@ private void setState(State state) {
222225
this.state = state;
223226
}
224227

225-
private void verifyClientProof(ClientFinalMessage clientFinalMessage) throws SaslException {
228+
// Visible for testing
229+
void verifyClientProof(ClientFinalMessage clientFinalMessage) throws SaslException {
226230
try {
227231
byte[] expectedStoredKey = scramCredential.storedKey();
228232
byte[] clientSignature = formatter.clientSignature(expectedStoredKey, clientFirstMessage, serverFirstMessage, clientFinalMessage);

clients/src/test/java/org/apache/kafka/common/security/scram/internals/ScramSaslServerTest.java

Lines changed: 71 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,17 +17,26 @@
1717
package org.apache.kafka.common.security.scram.internals;
1818

1919

20-
import java.nio.charset.StandardCharsets;
21-
import java.util.HashMap;
22-
2320
import org.apache.kafka.common.errors.SaslAuthenticationException;
2421
import org.apache.kafka.common.security.authenticator.CredentialCache;
2522
import org.apache.kafka.common.security.scram.ScramCredential;
23+
import org.apache.kafka.common.security.scram.internals.ScramMessages.ClientFirstMessage;
24+
import org.apache.kafka.common.security.scram.internals.ScramMessages.ServerFinalMessage;
25+
import org.apache.kafka.common.security.scram.internals.ScramMessages.ServerFirstMessage;
2626
import org.apache.kafka.common.security.token.delegation.internals.DelegationTokenCache;
2727

2828
import org.junit.jupiter.api.BeforeEach;
2929
import org.junit.jupiter.api.Test;
30+
import org.mockito.Mockito;
31+
32+
import java.nio.charset.StandardCharsets;
33+
import java.util.Base64;
34+
import java.util.HashMap;
3035

36+
import javax.security.sasl.SaslException;
37+
38+
import static org.junit.jupiter.api.Assertions.assertEquals;
39+
import static org.junit.jupiter.api.Assertions.assertNull;
3140
import static org.junit.jupiter.api.Assertions.assertThrows;
3241
import static org.junit.jupiter.api.Assertions.assertTrue;
3342

@@ -67,10 +76,69 @@ public void authorizationIdNotEqualsAuthenticationId() {
6776
assertThrows(SaslAuthenticationException.class, () -> saslServer.evaluateResponse(clientFirstMessage(USER_A, USER_B)));
6877
}
6978

79+
/**
80+
* Validate that server responds with client's nonce as prefix of its nonce in the
81+
* server first message.
82+
* <br>
83+
* In addition, it checks that the client final message has nonce that it sent in its
84+
* first message.
85+
*/
86+
@Test
87+
public void validateNonceExchange() throws SaslException {
88+
ScramSaslServer spySaslServer = Mockito.spy(saslServer);
89+
byte[] clientFirstMsgBytes = clientFirstMessage(USER_A, USER_A);
90+
ClientFirstMessage clientFirstMessage = new ClientFirstMessage(clientFirstMsgBytes);
91+
92+
byte[] serverFirstMsgBytes = spySaslServer.evaluateResponse(clientFirstMsgBytes);
93+
ServerFirstMessage serverFirstMessage = new ServerFirstMessage(serverFirstMsgBytes);
94+
assertTrue(serverFirstMessage.nonce().startsWith(clientFirstMessage.nonce()),
95+
"Nonce in server message should start with client first message's nonce");
96+
97+
byte[] clientFinalMessage = clientFinalMessage(serverFirstMessage.nonce());
98+
Mockito.doNothing()
99+
.when(spySaslServer).verifyClientProof(Mockito.any(ScramMessages.ClientFinalMessage.class));
100+
byte[] serverFinalMsgBytes = spySaslServer.evaluateResponse(clientFinalMessage);
101+
ServerFinalMessage serverFinalMessage = new ServerFinalMessage(serverFinalMsgBytes);
102+
assertNull(serverFinalMessage.error(), "Server final message should not contain error");
103+
}
104+
105+
@Test
106+
public void validateFailedNonceExchange() throws SaslException {
107+
ScramSaslServer spySaslServer = Mockito.spy(saslServer);
108+
byte[] clientFirstMsgBytes = clientFirstMessage(USER_A, USER_A);
109+
ClientFirstMessage clientFirstMessage = new ClientFirstMessage(clientFirstMsgBytes);
110+
111+
byte[] serverFirstMsgBytes = spySaslServer.evaluateResponse(clientFirstMsgBytes);
112+
ServerFirstMessage serverFirstMessage = new ServerFirstMessage(serverFirstMsgBytes);
113+
assertTrue(serverFirstMessage.nonce().startsWith(clientFirstMessage.nonce()),
114+
"Nonce in server message should start with client first message's nonce");
115+
116+
byte[] clientFinalMessage = clientFinalMessage(formatter.secureRandomString());
117+
Mockito.doNothing()
118+
.when(spySaslServer).verifyClientProof(Mockito.any(ScramMessages.ClientFinalMessage.class));
119+
SaslException saslException = assertThrows(SaslException.class,
120+
() -> spySaslServer.evaluateResponse(clientFinalMessage));
121+
assertEquals("Invalid client nonce in the final client message.",
122+
saslException.getMessage(),
123+
"Failure message: " + saslException.getMessage());
124+
}
125+
70126
private byte[] clientFirstMessage(String userName, String authorizationId) {
71127
String nonce = formatter.secureRandomString();
72128
String authorizationField = authorizationId != null ? "a=" + authorizationId : "";
73129
String firstMessage = String.format("n,%s,n=%s,r=%s", authorizationField, userName, nonce);
74130
return firstMessage.getBytes(StandardCharsets.UTF_8);
75131
}
132+
133+
private byte[] clientFinalMessage(String nonce) {
134+
String channelBinding = randomBytesAsString();
135+
String proof = randomBytesAsString();
136+
137+
String message = String.format("c=%s,r=%s,p=%s", channelBinding, nonce, proof);
138+
return message.getBytes(StandardCharsets.UTF_8);
139+
}
140+
141+
private String randomBytesAsString() {
142+
return Base64.getEncoder().encodeToString(formatter.secureRandomBytes());
143+
}
76144
}

0 commit comments

Comments
 (0)