Skip to content

Commit 885a18a

Browse files
authored
Correctly calculate the memory address even if ByteBuf might not have… (java-native-access#512)
… a memory adress that is directly accessible. Motivation: When we try to obtain a memory address of a ByteBuf we might need to fallback to obain it via the ByteBuffer. When doing so we also need to take the position of the ByteBuffer into account as it might be non-zero. Modifications: - Add new static helper methods to Quiche class that can be used to obtain the reader / writer memory address in a consistent / easy manner. - Use these new methods Result: Always use the correct memory address and so don't end up with data-corruption
1 parent 4f5dea5 commit 885a18a

4 files changed

Lines changed: 68 additions & 46 deletions

File tree

codec-classes-quic/src/main/java/io/netty/incubator/codec/quic/QuicHeaderParser.java

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -98,21 +98,25 @@ public void close() {
9898
public void parse(InetSocketAddress sender,
9999
InetSocketAddress recipient, ByteBuf packet, QuicHeaderProcessor callback) throws Exception {
100100
if (closed) {
101-
throw new IllegalStateException("QuicHeaderParser is already closed");
101+
throw new IllegalStateException(QuicHeaderParser.class.getSimpleName() + " is already closed");
102102
}
103-
long contentAddress = Quiche.memoryAddress(packet) + packet.readerIndex();
104-
int contentReadable = packet.readableBytes();
105103

106-
// Ret various len values so quiche_header_info can make use of these.
104+
// Set various len values so quiche_header_info can make use of these.
107105
scidLenBuffer.setInt(0, Quiche.QUICHE_MAX_CONN_ID_LEN);
108106
dcidLenBuffer.setInt(0, Quiche.QUICHE_MAX_CONN_ID_LEN);
109107
tokenLenBuffer.setInt(0, maxTokenLength);
110108

111-
int res = Quiche.quiche_header_info(contentAddress, contentReadable, localConnectionIdLength,
112-
Quiche.memoryAddress(versionBuffer), Quiche.memoryAddress(typeBuffer),
113-
Quiche.memoryAddress(scidBuffer), Quiche.memoryAddress(scidLenBuffer),
114-
Quiche.memoryAddress(dcidBuffer), Quiche.memoryAddress(dcidLenBuffer),
115-
Quiche.memoryAddress(tokenBuffer), Quiche.memoryAddress(tokenLenBuffer));
109+
int res = Quiche.quiche_header_info(
110+
Quiche.readerMemoryAddress(packet), packet.readableBytes(),
111+
localConnectionIdLength,
112+
Quiche.memoryAddress(versionBuffer, 0, versionBuffer.capacity()),
113+
Quiche.memoryAddress(typeBuffer, 0, versionBuffer.capacity()),
114+
Quiche.memoryAddress(scidBuffer, 0, scidBuffer.capacity()),
115+
Quiche.memoryAddress(scidLenBuffer, 0, scidLenBuffer.capacity()),
116+
Quiche.memoryAddress(dcidBuffer, 0, dcidBuffer.capacity()),
117+
Quiche.memoryAddress(dcidLenBuffer, 0, dcidLenBuffer.capacity()),
118+
Quiche.memoryAddress(tokenBuffer, 0, tokenBuffer.capacity()),
119+
Quiche.writerMemoryAddress(tokenLenBuffer));
116120
if (res >= 0) {
117121
int version = versionBuffer.getInt(0);
118122
byte type = typeBuffer.getByte(0);

codec-classes-quic/src/main/java/io/netty/incubator/codec/quic/Quiche.java

Lines changed: 31 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -670,24 +670,42 @@ static native void quiche_config_enable_dgram(long configAddr, boolean enable,
670670
static native int sockaddr_cmp(long addr, long addr2);
671671

672672
/**
673-
* Returns the memory address if the {@link ByteBuf}
673+
* Returns the memory address if the {@link ByteBuf} taking the readerIndex into account.
674+
*
675+
* @param buf the {@link ByteBuf} of which we want to obtain the memory address
676+
* (taking its {@link ByteBuf#readerIndex()} into account).
677+
* @return the memory address of this {@link ByteBuf}s readerIndex.
674678
*/
675-
static long memoryAddress(ByteBuf buf) {
676-
assert buf.isDirect();
677-
return buf.hasMemoryAddress() ? buf.memoryAddress() :
678-
buffer_memory_address(buf.internalNioBuffer(buf.readerIndex(), buf.readableBytes()));
679+
static long readerMemoryAddress(ByteBuf buf) {
680+
return memoryAddress(buf, buf.readerIndex(), buf.readableBytes());
679681
}
680682

681683
/**
682-
* Returns the memory address of the given {@link ByteBuffer}. If you want to also respect the
683-
* {@link ByteBuffer#position()} use {@link #memoryAddressWithPosition(ByteBuffer)}.
684+
* Returns the memory address if the {@link ByteBuf} taking the writerIndex into account.
684685
*
685-
* @param buf the {@link ByteBuffer} of which we want to obtain the memory address..
686-
* @return the memory address of this {@link ByteBuffer}.
686+
* @param buf the {@link ByteBuf} of which we want to obtain the memory address
687+
* (taking its {@link ByteBuf#writerIndex()} into account).
688+
* @return the memory address of this {@link ByteBuf}s writerIndex.
687689
*/
688-
static long memoryAddress(ByteBuffer buf) {
690+
static long writerMemoryAddress(ByteBuf buf) {
691+
return memoryAddress(buf, buf.writerIndex(), buf.writableBytes());
692+
}
693+
694+
/**
695+
* Returns the memory address if the {@link ByteBuf} taking the offset into account.
696+
*
697+
* @param buf the {@link ByteBuf} of which we want to obtain the memory address
698+
* (taking the {@code offset} into account).
699+
* @param offset the offset of the memory address.
700+
* @param len the length of the {@link ByteBuf}.
701+
* @return the memory address of this {@link ByteBuf}s offset.
702+
*/
703+
static long memoryAddress(ByteBuf buf, int offset, int len) {
689704
assert buf.isDirect();
690-
return buffer_memory_address(buf);
705+
if (buf.hasMemoryAddress()) {
706+
return buf.memoryAddress() + offset;
707+
}
708+
return memoryAddressWithPosition(buf.internalNioBuffer(offset, len));
691709
}
692710

693711
/**
@@ -699,7 +717,8 @@ static long memoryAddress(ByteBuffer buf) {
699717
* @return the memory address of this {@link ByteBuffer}s position.
700718
*/
701719
static long memoryAddressWithPosition(ByteBuffer buf) {
702-
return memoryAddress(buf) + buf.position();
720+
assert buf.isDirect();
721+
return buffer_memory_address(buf) + buf.position();
703722
}
704723

705724
@SuppressWarnings("deprecation")

codec-classes-quic/src/main/java/io/netty/incubator/codec/quic/QuicheQuicChannel.java

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -359,7 +359,7 @@ private void connect(Function<QuicChannel, ? extends QuicSslEngine> engineProvid
359359
int fromSockaddrLen = SockaddrIn.setAddress(fromSockaddrMemory, local);
360360
int toSockaddrLen = SockaddrIn.setAddress(toSockaddrMemory, remote);
361361
QuicheQuicConnection connection = quicheEngine.createConnection(ssl ->
362-
Quiche.quiche_conn_new_with_tls(Quiche.memoryAddress(idBuffer) + idBuffer.readerIndex(),
362+
Quiche.quiche_conn_new_with_tls(Quiche.readerMemoryAddress(idBuffer),
363363
idBuffer.readableBytes(), -1, -1,
364364
Quiche.memoryAddressWithPosition(fromSockaddrMemory), fromSockaddrLen,
365365
Quiche.memoryAddressWithPosition(toSockaddrMemory), toSockaddrLen,
@@ -590,7 +590,7 @@ protected void doClose() throws Exception {
590590

591591
failPendingConnectPromise();
592592
Quiche.throwIfError(Quiche.quiche_conn_close(connectionAddressChecked(), app, err,
593-
Quiche.memoryAddress(reason) + reason.readerIndex(), reason.readableBytes()));
593+
Quiche.readerMemoryAddress(reason), reason.readableBytes()));
594594

595595
// As we called quiche_conn_close(...) we need to ensure we will call quiche_conn_send(...) either
596596
// now or we will do so once we see the channelReadComplete event.
@@ -695,7 +695,7 @@ protected void doWrite(ChannelOutboundBuffer channelOutboundBuffer) throws Excep
695695

696696
private int sendDatagram(ByteBuf buf) throws ClosedChannelException {
697697
return Quiche.quiche_conn_dgram_send(connectionAddressChecked(),
698-
Quiche.memoryAddress(buf) + buf.readerIndex(), buf.readableBytes());
698+
Quiche.readerMemoryAddress(buf), buf.readableBytes());
699699
}
700700

701701
@Override
@@ -854,7 +854,7 @@ void connectionSendAndFlush() {
854854

855855
private int streamSend0(long streamId, ByteBuf buffer, boolean fin) throws ClosedChannelException {
856856
return Quiche.quiche_conn_stream_send(connectionAddressChecked(), streamId,
857-
Quiche.memoryAddress(buffer) + buffer.readerIndex(), buffer.readableBytes(), fin);
857+
Quiche.readerMemoryAddress(buffer), buffer.readableBytes(), fin);
858858
}
859859

860860
private int streamSend(long streamId, ByteBuffer buffer, boolean fin) throws ClosedChannelException {
@@ -867,14 +867,13 @@ StreamRecvResult streamRecv(long streamId, ByteBuf buffer) throws Exception {
867867
finBuffer = alloc().directBuffer(1);
868868
}
869869
int writerIndex = buffer.writerIndex();
870-
long memoryAddress = Quiche.memoryAddress(buffer);
871870
int recvLen = Quiche.quiche_conn_stream_recv(connectionAddressChecked(), streamId,
872-
memoryAddress + writerIndex, buffer.writableBytes(), Quiche.memoryAddress(finBuffer));
871+
Quiche.writerMemoryAddress(buffer), buffer.writableBytes(), Quiche.writerMemoryAddress(finBuffer));
873872
if (Quiche.throwIfError(recvLen)) {
874873
return StreamRecvResult.DONE;
875-
} else {
876-
buffer.writerIndex(writerIndex + recvLen);
877874
}
875+
876+
buffer.writerIndex(writerIndex + recvLen);
878877
return finBuffer.getBoolean(0) ? StreamRecvResult.FIN : StreamRecvResult.OK;
879878
}
880879

@@ -1052,7 +1051,7 @@ private boolean connectionSendSegments(SegmentedDatagramPacketAllocator segmente
10521051
boolean done;
10531052
int writerIndex = out.writerIndex();
10541053
int written = Quiche.quiche_conn_send(
1055-
connAddr, Quiche.memoryAddress(out) + writerIndex, out.writableBytes(),
1054+
connAddr, Quiche.writerMemoryAddress(out), out.writableBytes(),
10561055
Quiche.memoryAddressWithPosition(sendInfo));
10571056
if (written == 0) {
10581057
out.release();
@@ -1175,7 +1174,7 @@ private boolean connectionSendSimple() {
11751174
int writerIndex = out.writerIndex();
11761175

11771176
int written = Quiche.quiche_conn_send(
1178-
connAddr, Quiche.memoryAddress(out) + writerIndex, out.writableBytes(),
1177+
connAddr, Quiche.writerMemoryAddress(out), out.writableBytes(),
11791178
Quiche.memoryAddressWithPosition(sendInfo));
11801179

11811180
try {
@@ -1406,8 +1405,7 @@ void connectionRecv(InetSocketAddress recipient, InetSocketAddress sender, ByteB
14061405
tmpBuffer.writeBytes(buffer);
14071406
buffer = tmpBuffer;
14081407
}
1409-
int bufferReaderIndex = buffer.readerIndex();
1410-
long memoryAddress = Quiche.memoryAddress(buffer) + bufferReaderIndex;
1408+
long memoryAddress = Quiche.readerMemoryAddress(buffer);
14111409

14121410
ByteBuffer recvInfo = connection.nextRecvInfo();
14131411
QuicheRecvInfo.setRecvInfo(recvInfo, sender, recipient);
@@ -1463,7 +1461,7 @@ void connectionRecv(InetSocketAddress recipient, InetSocketAddress sender, ByteB
14631461
bufferReadable -= res;
14641462
} while (bufferReadable > 0);
14651463
} finally {
1466-
buffer.skipBytes((int) (memoryAddress - Quiche.memoryAddress(buffer)));
1464+
buffer.skipBytes((int) (memoryAddress - Quiche.readerMemoryAddress(buffer)));
14671465
if (tmpBuffer != null) {
14681466
tmpBuffer.release();
14691467
}

codec-classes-quic/src/main/java/io/netty/incubator/codec/quic/QuicheQuicServerCodec.java

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -123,9 +123,9 @@ private QuicheQuicChannel handleServer(ChannelHandlerContext ctx, InetSocketAddr
123123
int outWriterIndex = out.writerIndex();
124124

125125
int res = Quiche.quiche_negotiate_version(
126-
Quiche.memoryAddress(scid) + scid.readerIndex(), scid.readableBytes(),
127-
Quiche.memoryAddress(dcid) + dcid.readerIndex(), dcid.readableBytes(),
128-
Quiche.memoryAddress(out) + outWriterIndex, out.writableBytes());
126+
Quiche.readerMemoryAddress(scid), scid.readableBytes(),
127+
Quiche.readerMemoryAddress(dcid), dcid.readableBytes(),
128+
Quiche.writerMemoryAddress(out), out.writableBytes());
129129
if (res < 0) {
130130
out.release();
131131
Quiche.throwIfError(res);
@@ -150,12 +150,13 @@ private QuicheQuicChannel handleServer(ChannelHandlerContext ctx, InetSocketAddr
150150

151151
ByteBuf out = ctx.alloc().directBuffer(Quic.MAX_DATAGRAM_SIZE);
152152
int outWriterIndex = out.writerIndex();
153-
int written = Quiche.quiche_retry(Quiche.memoryAddress(scid) + scid.readerIndex(), scid.readableBytes(),
154-
Quiche.memoryAddress(dcid) + dcid.readerIndex(), dcid.readableBytes(),
155-
Quiche.memoryAddress(connIdBuffer) + connIdBuffer.readerIndex(), connIdBuffer.readableBytes(),
156-
Quiche.memoryAddress(mintTokenBuffer) + mintTokenBuffer.readerIndex(),
157-
mintTokenBuffer.readableBytes(),
158-
version, Quiche.memoryAddress(out) + outWriterIndex, out.writableBytes());
153+
int written = Quiche.quiche_retry(
154+
Quiche.readerMemoryAddress(scid), scid.readableBytes(),
155+
Quiche.readerMemoryAddress(dcid), dcid.readableBytes(),
156+
Quiche.readerMemoryAddress(connIdBuffer), connIdBuffer.readableBytes(),
157+
Quiche.readerMemoryAddress(mintTokenBuffer), mintTokenBuffer.readableBytes(),
158+
version,
159+
Quiche.writerMemoryAddress(out), out.writableBytes());
159160

160161
if (written < 0) {
161162
out.release();
@@ -188,7 +189,7 @@ private QuicheQuicChannel handleServer(ChannelHandlerContext ctx, InetSocketAddr
188189
key = connectionIdAddressGenerator.newId(
189190
dcid.internalNioBuffer(dcid.readerIndex(), dcid.readableBytes()), localConnIdLength);
190191
connIdBuffer.writeBytes(key.duplicate());
191-
scidAddr = Quiche.memoryAddress(connIdBuffer) + connIdBuffer.readerIndex();
192+
scidAddr = Quiche.readerMemoryAddress(connIdBuffer);
192193
scidLen = localConnIdLength;
193194
ocidAddr = -1;
194195
ocidLen = -1;
@@ -198,9 +199,9 @@ private QuicheQuicChannel handleServer(ChannelHandlerContext ctx, InetSocketAddr
198199
return existingChannel;
199200
}
200201
} else {
201-
scidAddr = Quiche.memoryAddress(dcid) + dcid.readerIndex();
202+
scidAddr = Quiche.readerMemoryAddress(dcid);
202203
scidLen = localConnIdLength;
203-
ocidAddr = Quiche.memoryAddress(token) + offset;
204+
ocidAddr = Quiche.memoryAddress(token, offset, token.readableBytes());
204205
ocidLen = token.readableBytes() - offset;
205206
// Now create the key to store the channel in the map.
206207
byte[] bytes = new byte[localConnIdLength];

0 commit comments

Comments
 (0)