Skip to content

Commit 765bfb1

Browse files
committed
protobuf,protobuf-lite: configurable protobuf recursion limit
1 parent 27edab0 commit 765bfb1

4 files changed

Lines changed: 153 additions & 20 deletions

File tree

protobuf-lite/src/main/java/io/grpc/protobuf/lite/ProtoLiteUtils.java

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,17 @@ public static void setExtensionRegistry(ExtensionRegistryLite newRegistry) {
8181
*/
8282
public static <T extends MessageLite> Marshaller<T> marshaller(T defaultInstance) {
8383
// TODO(ejona): consider changing return type to PrototypeMarshaller (assuming ABI safe)
84-
return new MessageMarshaller<>(defaultInstance);
84+
return new MessageMarshaller<>(defaultInstance, -1);
85+
}
86+
87+
/**
88+
* Creates a {@link Marshaller} for protos of the same type as {@code defaultInstance} and a
89+
* custom limit for the recursion depth. Any negative number will leave the limit to its default
90+
* value as defined by the protobuf library.
91+
*/
92+
public static <T extends MessageLite> Marshaller<T> marshallerWithRecursionLimit(
93+
T defaultInstance, int recursionLimit) {
94+
return new MessageMarshaller<>(defaultInstance, recursionLimit);
8595
}
8696

8797
/**
@@ -94,7 +104,9 @@ public static <T extends MessageLite> Metadata.BinaryMarshaller<T> metadataMarsh
94104
return new MetadataMarshaller<>(defaultInstance);
95105
}
96106

97-
/** Copies the data from input stream to output stream. */
107+
/**
108+
* Copies the data from input stream to output stream.
109+
*/
98110
static long copy(InputStream from, OutputStream to) throws IOException {
99111
// Copied from guava com.google.common.io.ByteStreams because its API is unstable (beta)
100112
checkNotNull(from, "inputStream cannot be null!");
@@ -117,18 +129,20 @@ private ProtoLiteUtils() {
117129

118130
private static final class MessageMarshaller<T extends MessageLite>
119131
implements PrototypeMarshaller<T> {
132+
120133
private static final ThreadLocal<Reference<byte[]>> bufs = new ThreadLocal<>();
121134

122135
private final Parser<T> parser;
123136
private final T defaultInstance;
137+
private final int recursionLimit;
124138

125139
@SuppressWarnings("unchecked")
126-
MessageMarshaller(T defaultInstance) {
127-
this.defaultInstance = defaultInstance;
128-
parser = (Parser<T>) defaultInstance.getParserForType();
140+
MessageMarshaller(T defaultInstance, int recursionLimit) {
141+
this.defaultInstance = checkNotNull(defaultInstance, "defaultInstance cannot be null");
142+
this.parser = (Parser<T>) defaultInstance.getParserForType();
143+
this.recursionLimit = recursionLimit;
129144
}
130145

131-
132146
@SuppressWarnings("unchecked")
133147
@Override
134148
public Class<T> getMessageClass() {
@@ -211,6 +225,10 @@ public T parse(InputStream stream) {
211225
// when parsing.
212226
cis.setSizeLimit(Integer.MAX_VALUE);
213227

228+
if (recursionLimit >= 0) {
229+
cis.setRecursionLimit(recursionLimit);
230+
}
231+
214232
try {
215233
return parseFrom(cis);
216234
} catch (InvalidProtocolBufferException ipbe) {

protobuf-lite/src/test/java/io/grpc/protobuf/lite/ProtoLiteUtilsTest.java

Lines changed: 105 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import static org.junit.Assert.assertEquals;
2121
import static org.junit.Assert.assertNotNull;
2222
import static org.junit.Assert.assertSame;
23+
import static org.junit.Assert.assertThrows;
2324
import static org.junit.Assert.fail;
2425

2526
import com.google.common.io.ByteStreams;
@@ -36,6 +37,8 @@
3637
import io.grpc.Status;
3738
import io.grpc.StatusRuntimeException;
3839
import io.grpc.internal.GrpcUtil;
40+
import io.grpc.testing.protobuf.SimpleRecursiveMessage;
41+
import io.grpc.testing.protobuf.SimpleRecursiveMessage.Builder;
3942
import java.io.ByteArrayInputStream;
4043
import java.io.ByteArrayOutputStream;
4144
import java.io.IOException;
@@ -47,14 +50,17 @@
4750
import org.junit.runner.RunWith;
4851
import org.junit.runners.JUnit4;
4952

50-
/** Unit tests for {@link ProtoLiteUtils}. */
53+
/**
54+
* Unit tests for {@link ProtoLiteUtils}.
55+
*/
5156
@RunWith(JUnit4.class)
5257
public class ProtoLiteUtilsTest {
5358

5459
@SuppressWarnings("deprecation") // https://github.com/grpc/grpc-java/issues/7467
55-
@Rule public final ExpectedException thrown = ExpectedException.none();
60+
@Rule
61+
public final ExpectedException thrown = ExpectedException.none();
5662

57-
private Marshaller<Type> marshaller = ProtoLiteUtils.marshaller(Type.getDefaultInstance());
63+
private final Marshaller<Type> marshaller = ProtoLiteUtils.marshaller(Type.getDefaultInstance());
5864
private Type proto = Type.newBuilder().setName("name").build();
5965

6066
@Test
@@ -85,8 +91,8 @@ public void testInvalidatedMessage() throws Exception {
8591
}
8692

8793
@Test
88-
public void parseInvalid() throws Exception {
89-
InputStream is = new ByteArrayInputStream(new byte[] {-127});
94+
public void parseInvalid() {
95+
InputStream is = new ByteArrayInputStream(new byte[]{-127});
9096
try {
9197
marshaller.parse(is);
9298
fail("Expected exception");
@@ -97,15 +103,15 @@ public void parseInvalid() throws Exception {
97103
}
98104

99105
@Test
100-
public void testMismatch() throws Exception {
106+
public void testMismatch() {
101107
Marshaller<Enum> enumMarshaller = ProtoLiteUtils.marshaller(Enum.getDefaultInstance());
102108
// Enum's name and Type's name are both strings with tag 1.
103109
Enum altProto = Enum.newBuilder().setName(proto.getName()).build();
104110
assertEquals(proto, marshaller.parse(enumMarshaller.stream(altProto)));
105111
}
106112

107113
@Test
108-
public void introspection() throws Exception {
114+
public void introspection() {
109115
Marshaller<Enum> enumMarshaller = ProtoLiteUtils.marshaller(Enum.getDefaultInstance());
110116
PrototypeMarshaller<Enum> prototypeMarshaller = (PrototypeMarshaller<Enum>) enumMarshaller;
111117
assertSame(Enum.getDefaultInstance(), prototypeMarshaller.getMessagePrototype());
@@ -130,7 +136,8 @@ public void testAvailable() throws Exception {
130136
assertEquals(proto.getSerializedSize(), is.available());
131137
is.read();
132138
assertEquals(proto.getSerializedSize() - 1, is.available());
133-
while (is.read() != -1) {}
139+
while (is.read() != -1) {
140+
}
134141
assertEquals(-1, is.read());
135142
assertEquals(0, is.available());
136143
}
@@ -203,7 +210,7 @@ public void metadataMarshaller_invalid() {
203210
Metadata.BinaryMarshaller<Type> metadataMarshaller =
204211
ProtoLiteUtils.metadataMarshaller(Type.getDefaultInstance());
205212
try {
206-
metadataMarshaller.parseBytes(new byte[] {-127});
213+
metadataMarshaller.parseBytes(new byte[]{-127});
207214
fail("Expected exception");
208215
} catch (IllegalArgumentException ex) {
209216
assertNotNull(((InvalidProtocolBufferException) ex.getCause()).getUnfinishedMessage());
@@ -219,7 +226,7 @@ public void extensionRegistry_notNull() {
219226
}
220227

221228
@Test
222-
public void parseFromKnowLengthInputStream() throws Exception {
229+
public void parseFromKnowLengthInputStream() {
223230
Marshaller<Type> marshaller = ProtoLiteUtils.marshaller(Type.getDefaultInstance());
224231
Type expect = Type.newBuilder().setName("expected name").build();
225232

@@ -232,21 +239,106 @@ public void defaultMaxMessageSize() {
232239
assertEquals(GrpcUtil.DEFAULT_MAX_MESSAGE_SIZE, ProtoLiteUtils.DEFAULT_MAX_MESSAGE_SIZE);
233240
}
234241

242+
@Test
243+
public void testNullDefaultInstance() {
244+
String expectedMessage = "defaultInstance cannot be null";
245+
assertThrows(expectedMessage, NullPointerException.class,
246+
() -> ProtoLiteUtils.marshaller(null));
247+
248+
assertThrows(expectedMessage, NullPointerException.class,
249+
() -> ProtoLiteUtils.marshallerWithRecursionLimit(null, 10)
250+
);
251+
}
252+
253+
@Test
254+
public void givenPositiveLimit_testRecursionLimitExceeded() throws IOException {
255+
Marshaller<SimpleRecursiveMessage> marshaller = ProtoLiteUtils.marshallerWithRecursionLimit(
256+
SimpleRecursiveMessage.getDefaultInstance(), 10);
257+
SimpleRecursiveMessage message = buildRecursiveMessage(12);
258+
259+
assertRecursionLimitExceeded(marshaller, message);
260+
}
261+
262+
@Test
263+
public void givenZeroLimit_testRecursionLimitExceeded() throws IOException {
264+
Marshaller<SimpleRecursiveMessage> marshaller = ProtoLiteUtils.marshallerWithRecursionLimit(
265+
SimpleRecursiveMessage.getDefaultInstance(), 0);
266+
SimpleRecursiveMessage message = buildRecursiveMessage(1);
267+
268+
assertRecursionLimitExceeded(marshaller, message);
269+
}
270+
271+
@Test
272+
public void givenPositiveLimit_testRecursionLimitNotExceeded() throws IOException {
273+
Marshaller<SimpleRecursiveMessage> marshaller = ProtoLiteUtils.marshallerWithRecursionLimit(
274+
SimpleRecursiveMessage.getDefaultInstance(), 15);
275+
SimpleRecursiveMessage message = buildRecursiveMessage(12);
276+
277+
assertRecursionLimitNotExceeded(marshaller, message);
278+
}
279+
280+
@Test
281+
public void givenZeroLimit_testRecursionLimitNotExceeded() throws IOException {
282+
Marshaller<SimpleRecursiveMessage> marshaller = ProtoLiteUtils.marshallerWithRecursionLimit(
283+
SimpleRecursiveMessage.getDefaultInstance(), 0);
284+
SimpleRecursiveMessage message = buildRecursiveMessage(0);
285+
286+
assertRecursionLimitNotExceeded(marshaller, message);
287+
}
288+
289+
@Test
290+
public void testDefaultRecursionLimit() throws IOException {
291+
Marshaller<SimpleRecursiveMessage> marshaller = ProtoLiteUtils.marshaller(
292+
SimpleRecursiveMessage.getDefaultInstance());
293+
SimpleRecursiveMessage message = buildRecursiveMessage(100);
294+
295+
assertRecursionLimitNotExceeded(marshaller, message);
296+
}
297+
298+
private static void assertRecursionLimitExceeded(Marshaller<SimpleRecursiveMessage> marshaller,
299+
SimpleRecursiveMessage message) throws IOException {
300+
InputStream is = marshaller.stream(message);
301+
ByteArrayInputStream bais = new ByteArrayInputStream(ByteStreams.toByteArray(is));
302+
303+
assertThrows(StatusRuntimeException.class, () -> marshaller.parse(bais));
304+
}
305+
306+
private static void assertRecursionLimitNotExceeded(Marshaller<SimpleRecursiveMessage> marshaller,
307+
SimpleRecursiveMessage message) throws IOException {
308+
InputStream is = marshaller.stream(message);
309+
ByteArrayInputStream bais = new ByteArrayInputStream(ByteStreams.toByteArray(is));
310+
311+
assertEquals(message, marshaller.parse(bais));
312+
}
313+
314+
private static SimpleRecursiveMessage buildRecursiveMessage(int depth) {
315+
SimpleRecursiveMessage.Builder builder = SimpleRecursiveMessage.newBuilder()
316+
.setValue("depth-" + depth);
317+
for (int i = depth; i > 0; i--) {
318+
builder = SimpleRecursiveMessage.newBuilder()
319+
.setValue("depth-" + i)
320+
.setMessage(builder.build());
321+
}
322+
323+
return builder.build();
324+
}
325+
235326
private static class CustomKnownLengthInputStream extends InputStream implements KnownLength {
327+
236328
private int position = 0;
237-
private byte[] source;
329+
private final byte[] source;
238330

239331
private CustomKnownLengthInputStream(byte[] source) {
240332
this.source = source;
241333
}
242334

243335
@Override
244-
public int available() throws IOException {
336+
public int available() {
245337
return source.length - position;
246338
}
247339

248340
@Override
249-
public int read() throws IOException {
341+
public int read() {
250342
if (position == source.length) {
251343
return -1;
252344
}
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
syntax = "proto3";
2+
3+
package grpc.testing;
4+
5+
option java_package = "io.grpc.testing.protobuf";
6+
option java_outer_classname = "SimpleRecursiveProto";
7+
option java_multiple_files = true;
8+
9+
// A simple recursive message for testing purposes
10+
message SimpleRecursiveMessage {
11+
string value = 1;
12+
SimpleRecursiveMessage message = 2;
13+
}

protobuf/src/main/java/io/grpc/protobuf/ProtoUtils.java

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,16 @@ public static <T extends Message> Marshaller<T> marshaller(final T defaultInstan
5757
return ProtoLiteUtils.marshaller(defaultInstance);
5858
}
5959

60+
/**
61+
* Creates a {@link Marshaller} for protos of the same type as {@code defaultInstance} and a
62+
* custom limit for the recursion depth. Any negative number will leave the limit to its default
63+
* value as defined by the protobuf library.
64+
*/
65+
public static <T extends Message> Marshaller<T> marshallerWithRecursionLimit(T defaultInstance,
66+
int recursionLimit) {
67+
return ProtoLiteUtils.marshallerWithRecursionLimit(defaultInstance, recursionLimit);
68+
}
69+
6070
/**
6171
* Produce a metadata key for a generated protobuf type.
6272
*
@@ -70,7 +80,7 @@ public static <T extends Message> Metadata.Key<T> keyForProto(T instance) {
7080

7181
/**
7282
* Produce a metadata marshaller for a protobuf type.
73-
*
83+
*
7484
* @since 1.13.0
7585
*/
7686
@ExperimentalApi("https://github.com/grpc/grpc-java/issues/4477")

0 commit comments

Comments
 (0)