Skip to content

Commit 04c2de6

Browse files
authored
Fix OutputSampler's coder. (apache#25805)
1 parent afce68d commit 04c2de6

4 files changed

Lines changed: 110 additions & 47 deletions

File tree

sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/PCollectionConsumerRegistry.java

Lines changed: 18 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -195,8 +195,14 @@ public FnDataReceiver<WindowedValue<?>> getMultiplexingConsumer(String pCollecti
195195
String coderId =
196196
processBundleDescriptor.getPcollectionsOrThrow(pCollectionId).getCoderId();
197197
Coder<?> coder;
198+
OutputSampler<?> sampler = null;
198199
try {
199200
Coder<?> maybeWindowedValueInputCoder = rehydratedComponents.getCoder(coderId);
201+
202+
if (dataSampler != null) {
203+
sampler = dataSampler.sampleOutput(pCollectionId, maybeWindowedValueInputCoder);
204+
}
205+
200206
// TODO: Stop passing windowed value coders within PCollections.
201207
if (maybeWindowedValueInputCoder instanceof WindowedValue.WindowedValueCoder) {
202208
coder = ((WindowedValueCoder) maybeWindowedValueInputCoder).getValueCoder();
@@ -215,16 +221,16 @@ public FnDataReceiver<WindowedValue<?>> getMultiplexingConsumer(String pCollecti
215221
ConsumerAndMetadata consumerAndMetadata = consumerAndMetadatas.get(0);
216222
if (consumerAndMetadata.getConsumer() instanceof HandlesSplits) {
217223
return new SplittingMetricTrackingFnDataReceiver(
218-
pcId, coder, consumerAndMetadata, dataSampler);
224+
pcId, coder, consumerAndMetadata, sampler);
219225
}
220-
return new MetricTrackingFnDataReceiver(pcId, coder, consumerAndMetadata, dataSampler);
226+
return new MetricTrackingFnDataReceiver(pcId, coder, consumerAndMetadata, sampler);
221227
} else {
222228
/* TODO(SDF), Consider supporting splitting each consumer individually. This would never
223229
come up in the existing SDF expansion, but might be useful to support fused SDF nodes.
224230
This would require dedicated delivery of the split results to each of the consumers
225231
separately. */
226232
return new MultiplexingMetricTrackingFnDataReceiver(
227-
pcId, coder, consumerAndMetadatas, dataSampler);
233+
pcId, coder, consumerAndMetadatas, sampler);
228234
}
229235
});
230236
}
@@ -248,7 +254,7 @@ public MetricTrackingFnDataReceiver(
248254
String pCollectionId,
249255
Coder<T> coder,
250256
ConsumerAndMetadata consumerAndMetadata,
251-
@Nullable DataSampler dataSampler) {
257+
@Nullable OutputSampler<T> outputSampler) {
252258
this.delegate = consumerAndMetadata.getConsumer();
253259
this.executionState = consumerAndMetadata.getExecutionState();
254260

@@ -284,11 +290,7 @@ public MetricTrackingFnDataReceiver(
284290
bundleProgressReporterRegistrar.register(sampledByteSizeUnderlyingDistribution);
285291

286292
this.coder = coder;
287-
if (dataSampler == null) {
288-
this.outputSampler = null;
289-
} else {
290-
this.outputSampler = dataSampler.sampleOutput(pCollectionId, coder);
291-
}
293+
this.outputSampler = outputSampler;
292294
}
293295

294296
@Override
@@ -300,7 +302,7 @@ public void accept(WindowedValue<T> input) throws Exception {
300302
this.sampledByteSizeDistribution.tryUpdate(input.getValue(), this.coder);
301303

302304
if (outputSampler != null) {
303-
outputSampler.sample(input.getValue());
305+
outputSampler.sample(input);
304306
}
305307

306308
// Use the ExecutionStateTracker and enter an appropriate state to track the
@@ -329,13 +331,13 @@ private class MultiplexingMetricTrackingFnDataReceiver<T>
329331
private final BundleCounter elementCountCounter;
330332
private final SampleByteSizeDistribution<T> sampledByteSizeDistribution;
331333
private final Coder<T> coder;
332-
private final @Nullable OutputSampler<T> outputSampler;
334+
private @Nullable OutputSampler<T> outputSampler = null;
333335

334336
public MultiplexingMetricTrackingFnDataReceiver(
335337
String pCollectionId,
336338
Coder<T> coder,
337339
List<ConsumerAndMetadata> consumerAndMetadatas,
338-
@Nullable DataSampler dataSampler) {
340+
@Nullable OutputSampler<T> outputSampler) {
339341
this.consumerAndMetadatas = consumerAndMetadatas;
340342

341343
HashMap<String, String> labels = new HashMap<>();
@@ -370,11 +372,7 @@ public MultiplexingMetricTrackingFnDataReceiver(
370372
bundleProgressReporterRegistrar.register(sampledByteSizeUnderlyingDistribution);
371373

372374
this.coder = coder;
373-
if (dataSampler == null) {
374-
this.outputSampler = null;
375-
} else {
376-
this.outputSampler = dataSampler.sampleOutput(pCollectionId, coder);
377-
}
375+
this.outputSampler = outputSampler;
378376
}
379377

380378
@Override
@@ -386,7 +384,7 @@ public void accept(WindowedValue<T> input) throws Exception {
386384
this.sampledByteSizeDistribution.tryUpdate(input.getValue(), coder);
387385

388386
if (outputSampler != null) {
389-
outputSampler.sample(input.getValue());
387+
outputSampler.sample(input);
390388
}
391389

392390
// Use the ExecutionStateTracker and enter an appropriate state to track the
@@ -422,8 +420,8 @@ public SplittingMetricTrackingFnDataReceiver(
422420
String pCollection,
423421
Coder<T> coder,
424422
ConsumerAndMetadata consumerAndMetadata,
425-
@Nullable DataSampler dataSampler) {
426-
super(pCollection, coder, consumerAndMetadata, dataSampler);
423+
@Nullable OutputSampler<T> outputSampler) {
424+
super(pCollection, coder, consumerAndMetadata, outputSampler);
427425
this.delegate = (HandlesSplits) consumerAndMetadata.getConsumer();
428426
}
429427

sdks/java/harness/src/main/java/org/apache/beam/fn/harness/debug/OutputSampler.java

Lines changed: 29 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,11 @@
2121
import java.util.ArrayList;
2222
import java.util.List;
2323
import java.util.concurrent.atomic.AtomicLong;
24+
import javax.annotation.Nullable;
2425
import org.apache.beam.model.fnexecution.v1.BeamFnApi;
2526
import org.apache.beam.sdk.coders.Coder;
2627
import org.apache.beam.sdk.util.ByteStringOutputStream;
28+
import org.apache.beam.sdk.util.WindowedValue;
2729

2830
/**
2931
* This class holds samples for a single PCollection until queried by the parent DataSampler. This
@@ -35,7 +37,7 @@
3537
public class OutputSampler<T> {
3638

3739
// Temporarily holds elements until the SDK receives a sample data request.
38-
private List<T> buffer;
40+
private List<WindowedValue<T>> buffer;
3941

4042
// Maximum number of elements in buffer.
4143
private final int maxElements;
@@ -49,13 +51,27 @@ public class OutputSampler<T> {
4951
// Index into the buffer of where to overwrite samples.
5052
private int resampleIndex = 0;
5153

52-
private final Coder<T> coder;
54+
@Nullable private final Coder<T> valueCoder;
5355

54-
public OutputSampler(Coder<T> coder, int maxElements, int sampleEveryN) {
55-
this.coder = coder;
56+
@Nullable private final Coder<WindowedValue<T>> windowedValueCoder;
57+
58+
public OutputSampler(Coder<?> coder, int maxElements, int sampleEveryN) {
5659
this.maxElements = maxElements;
5760
this.sampleEveryN = sampleEveryN;
5861
this.buffer = new ArrayList<>(this.maxElements);
62+
63+
// The samples taken and encoded should match exactly to the specification from the
64+
// ProcessBundleDescriptor. The coder given can either be a WindowedValueCoder, in which the
65+
// element itself is sampled. Or, it's non a WindowedValueCoder and the value inside the
66+
// windowed value must be sampled. This is because WindowedValue is the element type used in
67+
// all receivers, which doesn't necessarily match the PBD encoding.
68+
if (coder instanceof WindowedValue.WindowedValueCoder) {
69+
this.valueCoder = null;
70+
this.windowedValueCoder = (Coder<WindowedValue<T>>) coder;
71+
} else {
72+
this.valueCoder = (Coder<T>) coder;
73+
this.windowedValueCoder = null;
74+
}
5975
}
6076

6177
/**
@@ -67,7 +83,7 @@ public OutputSampler(Coder<T> coder, int maxElements, int sampleEveryN) {
6783
*
6884
* @param element the element to sample.
6985
*/
70-
public void sample(T element) {
86+
public void sample(WindowedValue<T> element) {
7187
// Only sample the first 10 elements then after every `sampleEveryN`th element.
7288
long samples = numSamples.get() + 1;
7389

@@ -104,7 +120,7 @@ public List<BeamFnApi.SampledElement> samples() throws IOException {
104120

105121
// Serializing can take a lot of CPU time for larger or complex elements. Copy the array here
106122
// so as to not slow down the main processing hot path.
107-
List<T> bufferToSend;
123+
List<WindowedValue<T>> bufferToSend;
108124
int sampleIndex = 0;
109125
synchronized (this) {
110126
bufferToSend = buffer;
@@ -116,10 +132,13 @@ public List<BeamFnApi.SampledElement> samples() throws IOException {
116132
ByteStringOutputStream stream = new ByteStringOutputStream();
117133
for (int i = 0; i < bufferToSend.size(); i++) {
118134
int index = (sampleIndex + i) % bufferToSend.size();
119-
// This is deprecated, but until this is fully removed, this specifically needs the nested
120-
// context. This is because the SDK will need to decode the sampled elements with the
121-
// ToStringFn.
122-
coder.encode(bufferToSend.get(index), stream, Coder.Context.NESTED);
135+
136+
if (valueCoder != null) {
137+
this.valueCoder.encode(bufferToSend.get(index).getValue(), stream, Coder.Context.NESTED);
138+
} else if (windowedValueCoder != null) {
139+
this.windowedValueCoder.encode(bufferToSend.get(index), stream, Coder.Context.NESTED);
140+
}
141+
123142
ret.add(
124143
BeamFnApi.SampledElement.newBuilder().setElement(stream.toByteStringAndReset()).build());
125144
}

sdks/java/harness/src/test/java/org/apache/beam/fn/harness/debug/DataSamplerTest.java

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
import org.apache.beam.sdk.coders.Coder;
3737
import org.apache.beam.sdk.coders.StringUtf8Coder;
3838
import org.apache.beam.sdk.coders.VarIntCoder;
39+
import org.apache.beam.sdk.util.WindowedValue;
3940
import org.apache.beam.vendor.grpc.v1p48p1.com.google.protobuf.ByteString;
4041
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
4142
import org.junit.Test;
@@ -65,6 +66,10 @@ byte[] encodeByteArray(byte[] b) throws IOException {
6566
return stream.toByteArray();
6667
}
6768

69+
<T> WindowedValue<T> globalWindowedValue(T el) {
70+
return WindowedValue.valueInGlobalWindow(el);
71+
}
72+
6873
BeamFnApi.InstructionResponse getAllSamples(DataSampler dataSampler) {
6974
BeamFnApi.InstructionRequest request =
7075
BeamFnApi.InstructionRequest.newBuilder()
@@ -122,7 +127,7 @@ public void testSingleOutput() throws Exception {
122127
DataSampler sampler = new DataSampler();
123128

124129
VarIntCoder coder = VarIntCoder.of();
125-
sampler.sampleOutput("pcollection-id", coder).sample(1);
130+
sampler.sampleOutput("pcollection-id", coder).sample(globalWindowedValue(1));
126131

127132
BeamFnApi.InstructionResponse samples = getAllSamples(sampler);
128133
assertHasSamples(samples, "pcollection-id", Collections.singleton(encodeInt(1)));
@@ -140,7 +145,7 @@ public void testNestedContext() throws Exception {
140145
String rawString = "hello";
141146
byte[] byteArray = rawString.getBytes(StandardCharsets.US_ASCII);
142147
ByteArrayCoder coder = ByteArrayCoder.of();
143-
sampler.sampleOutput("pcollection-id", coder).sample(byteArray);
148+
sampler.sampleOutput("pcollection-id", coder).sample(globalWindowedValue(byteArray));
144149

145150
BeamFnApi.InstructionResponse samples = getAllSamples(sampler);
146151
assertHasSamples(samples, "pcollection-id", Collections.singleton(encodeByteArray(byteArray)));
@@ -156,8 +161,8 @@ public void testMultipleOutputs() throws Exception {
156161
DataSampler sampler = new DataSampler();
157162

158163
VarIntCoder coder = VarIntCoder.of();
159-
sampler.sampleOutput("pcollection-id-1", coder).sample(1);
160-
sampler.sampleOutput("pcollection-id-2", coder).sample(2);
164+
sampler.sampleOutput("pcollection-id-1", coder).sample(globalWindowedValue(1));
165+
sampler.sampleOutput("pcollection-id-2", coder).sample(globalWindowedValue(2));
161166

162167
BeamFnApi.InstructionResponse samples = getAllSamples(sampler);
163168
assertHasSamples(samples, "pcollection-id-1", Collections.singleton(encodeInt(1)));
@@ -174,21 +179,21 @@ public void testMultipleSamePCollections() throws Exception {
174179
DataSampler sampler = new DataSampler();
175180

176181
VarIntCoder coder = VarIntCoder.of();
177-
sampler.sampleOutput("pcollection-id", coder).sample(1);
178-
sampler.sampleOutput("pcollection-id", coder).sample(2);
182+
sampler.sampleOutput("pcollection-id", coder).sample(globalWindowedValue(1));
183+
sampler.sampleOutput("pcollection-id", coder).sample(globalWindowedValue(2));
179184

180185
BeamFnApi.InstructionResponse samples = getAllSamples(sampler);
181186
assertHasSamples(samples, "pcollection-id", ImmutableList.of(encodeInt(1), encodeInt(2)));
182187
}
183188

184189
void generateStringSamples(DataSampler sampler) {
185190
StringUtf8Coder coder = StringUtf8Coder.of();
186-
sampler.sampleOutput("a", coder).sample("a1");
187-
sampler.sampleOutput("a", coder).sample("a2");
188-
sampler.sampleOutput("b", coder).sample("b1");
189-
sampler.sampleOutput("b", coder).sample("b2");
190-
sampler.sampleOutput("c", coder).sample("c1");
191-
sampler.sampleOutput("c", coder).sample("c2");
191+
sampler.sampleOutput("a", coder).sample(globalWindowedValue("a1"));
192+
sampler.sampleOutput("a", coder).sample(globalWindowedValue("a2"));
193+
sampler.sampleOutput("b", coder).sample(globalWindowedValue("b1"));
194+
sampler.sampleOutput("b", coder).sample(globalWindowedValue("b2"));
195+
sampler.sampleOutput("c", coder).sample(globalWindowedValue("c1"));
196+
sampler.sampleOutput("c", coder).sample(globalWindowedValue("c2"));
192197
}
193198

194199
/**
@@ -250,7 +255,7 @@ public void testConcurrentNewSampler() throws Exception {
250255
}
251256

252257
for (int j = 0; j < 100; j++) {
253-
sampler.sampleOutput("pcollection-" + j, coder).sample(0);
258+
sampler.sampleOutput("pcollection-" + j, coder).sample(globalWindowedValue(0));
254259
}
255260

256261
doneSignal.countDown();

sdks/java/harness/src/test/java/org/apache/beam/fn/harness/debug/OutputSamplerTest.java

Lines changed: 45 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,10 @@
2929
import java.util.concurrent.CountDownLatch;
3030
import org.apache.beam.model.fnexecution.v1.BeamFnApi;
3131
import org.apache.beam.sdk.coders.VarIntCoder;
32+
import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
33+
import org.apache.beam.sdk.util.WindowedValue;
3234
import org.apache.beam.vendor.grpc.v1p48p1.com.google.protobuf.ByteString;
35+
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
3336
import org.junit.Test;
3437
import org.junit.runner.RunWith;
3538
import org.junit.runners.JUnit4;
@@ -45,6 +48,17 @@ public BeamFnApi.SampledElement encodeInt(Integer i) throws IOException {
4548
.build();
4649
}
4750

51+
public BeamFnApi.SampledElement encodeGlobalWindowedInt(Integer i) throws IOException {
52+
WindowedValue.WindowedValueCoder<Integer> coder =
53+
WindowedValue.FullWindowedValueCoder.of(VarIntCoder.of(), GlobalWindow.Coder.INSTANCE);
54+
55+
ByteArrayOutputStream stream = new ByteArrayOutputStream();
56+
coder.encode(WindowedValue.valueInGlobalWindow(i), stream);
57+
return BeamFnApi.SampledElement.newBuilder()
58+
.setElement(ByteString.copyFrom(stream.toByteArray()))
59+
.build();
60+
}
61+
4862
/**
4963
* Test that the first N are always sampled.
5064
*
@@ -57,7 +71,7 @@ public void testSamplesFirstN() throws Exception {
5771

5872
// Purposely go over maxSamples and sampleEveryN. This helps to increase confidence.
5973
for (int i = 0; i < 15; ++i) {
60-
outputSampler.sample(i);
74+
outputSampler.sample(WindowedValue.valueInGlobalWindow(i));
6175
}
6276

6377
// The expected list is only 0..9 inclusive.
@@ -70,6 +84,33 @@ public void testSamplesFirstN() throws Exception {
7084
assertThat(samples, containsInAnyOrder(expected.toArray()));
7185
}
7286

87+
@Test
88+
public void testWindowedValueSample() throws Exception {
89+
WindowedValue.WindowedValueCoder<Integer> coder =
90+
WindowedValue.FullWindowedValueCoder.of(VarIntCoder.of(), GlobalWindow.Coder.INSTANCE);
91+
92+
OutputSampler<Integer> outputSampler = new OutputSampler<>(coder, 10, 10);
93+
outputSampler.sample(WindowedValue.valueInGlobalWindow(0));
94+
95+
// The expected list is only 0..9 inclusive.
96+
List<BeamFnApi.SampledElement> expected = ImmutableList.of(encodeGlobalWindowedInt(0));
97+
List<BeamFnApi.SampledElement> samples = outputSampler.samples();
98+
assertThat(samples, containsInAnyOrder(expected.toArray()));
99+
}
100+
101+
@Test
102+
public void testNonWindowedValueSample() throws Exception {
103+
VarIntCoder coder = VarIntCoder.of();
104+
105+
OutputSampler<Integer> outputSampler = new OutputSampler<>(coder, 10, 10);
106+
outputSampler.sample(WindowedValue.valueInGlobalWindow(0));
107+
108+
// The expected list is only 0..9 inclusive.
109+
List<BeamFnApi.SampledElement> expected = ImmutableList.of(encodeInt(0));
110+
List<BeamFnApi.SampledElement> samples = outputSampler.samples();
111+
assertThat(samples, containsInAnyOrder(expected.toArray()));
112+
}
113+
73114
/**
74115
* Test that the previous values are overwritten and only the most recent `maxSamples` are kept.
75116
*
@@ -81,7 +122,7 @@ public void testActsLikeCircularBuffer() throws Exception {
81122
OutputSampler<Integer> outputSampler = new OutputSampler<>(coder, 5, 20);
82123

83124
for (int i = 0; i < 100; ++i) {
84-
outputSampler.sample(i);
125+
outputSampler.sample(WindowedValue.valueInGlobalWindow(i));
85126
}
86127

87128
// The first 10 are always sampled, but with maxSamples = 5, the first ten are downsampled to
@@ -124,7 +165,7 @@ public void testConcurrentSamples() throws Exception {
124165
}
125166

126167
for (int i = 0; i < 1000000; i++) {
127-
outputSampler.sample(i);
168+
outputSampler.sample(WindowedValue.valueInGlobalWindow(i));
128169
}
129170

130171
doneSignal.countDown();
@@ -141,7 +182,7 @@ public void testConcurrentSamples() throws Exception {
141182
}
142183

143184
for (int i = -1000000; i < 0; i++) {
144-
outputSampler.sample(i);
185+
outputSampler.sample(WindowedValue.valueInGlobalWindow(i));
145186
}
146187

147188
doneSignal.countDown();

0 commit comments

Comments
 (0)