
Commit 186be58

fix test failure
1 parent 3fc31d8 commit 186be58

File tree

1 file changed: +86 -56 lines changed


sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/TextSocketStreamSuite.scala

Lines changed: 86 additions & 56 deletions
@@ -22,10 +22,13 @@ import java.net.ServerSocket
 import java.sql.Timestamp
 import java.util.concurrent.LinkedBlockingQueue
 
+import scala.collection.mutable
+
 import org.scalatest.BeforeAndAfterEach
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.internal.SQLConf._
 import org.apache.spark.sql.streaming.StreamTest
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types.{StringType, StructField, StructType, TimestampType}
@@ -60,29 +63,31 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before
 
     source = provider.createSource(sqlContext, "", None, "", parameters)
 
-    failAfter(streamingTimeout) {
-      serverThread.enqueue("hello")
-      while (source.getOffset.isEmpty) {
-        Thread.sleep(10)
-      }
-      val offset1 = source.getOffset.get
-      val batch1 = source.getBatch(None, offset1)
-      assert(batch1.as[String].collect().toSeq === Seq("hello"))
+    withAdditionalConf(Map(UNSUPPORTED_OPERATION_CHECK_ENABLED.key -> "false")) { () =>
+      failAfter(streamingTimeout) {
+        serverThread.enqueue("hello")
+        while (source.getOffset.isEmpty) {
+          Thread.sleep(10)
+        }
+        val offset1 = source.getOffset.get
+        val batch1 = source.getBatch(None, offset1)
+        assert(batch1.as[String].collect().toSeq === Seq("hello"))
 
-      serverThread.enqueue("world")
-      while (source.getOffset.get === offset1) {
-        Thread.sleep(10)
-      }
-      val offset2 = source.getOffset.get
-      val batch2 = source.getBatch(Some(offset1), offset2)
-      assert(batch2.as[String].collect().toSeq === Seq("world"))
+        serverThread.enqueue("world")
+        while (source.getOffset.get === offset1) {
+          Thread.sleep(10)
+        }
+        val offset2 = source.getOffset.get
+        val batch2 = source.getBatch(Some(offset1), offset2)
+        assert(batch2.as[String].collect().toSeq === Seq("world"))
 
-      val both = source.getBatch(None, offset2)
-      assert(both.as[String].collect().sorted.toSeq === Seq("hello", "world"))
+        val both = source.getBatch(None, offset2)
+        assert(both.as[String].collect().sorted.toSeq === Seq("hello", "world"))
 
-      // Try stopping the source to make sure this does not block forever.
-      source.stop()
-      source = null
+        // Try stopping the source to make sure this does not block forever.
+        source.stop()
+        source = null
+      }
     }
   }
 
@@ -99,31 +104,33 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before
 
     source = provider.createSource(sqlContext, "", None, "", parameters)
 
-    failAfter(streamingTimeout) {
-      serverThread.enqueue("hello")
-      while (source.getOffset.isEmpty) {
-        Thread.sleep(10)
-      }
-      val offset1 = source.getOffset.get
-      val batch1 = source.getBatch(None, offset1)
-      val batch1Seq = batch1.as[(String, Timestamp)].collect().toSeq
-      assert(batch1Seq.map(_._1) === Seq("hello"))
-      val batch1Stamp = batch1Seq(0)._2
-
-      serverThread.enqueue("world")
-      while (source.getOffset.get === offset1) {
-        Thread.sleep(10)
+    withAdditionalConf(Map(UNSUPPORTED_OPERATION_CHECK_ENABLED.key -> "false")) { () =>
+      failAfter(streamingTimeout) {
+        serverThread.enqueue("hello")
+        while (source.getOffset.isEmpty) {
+          Thread.sleep(10)
+        }
+        val offset1 = source.getOffset.get
+        val batch1 = source.getBatch(None, offset1)
+        val batch1Seq = batch1.as[(String, Timestamp)].collect().toSeq
+        assert(batch1Seq.map(_._1) === Seq("hello"))
+        val batch1Stamp = batch1Seq(0)._2
+
+        serverThread.enqueue("world")
+        while (source.getOffset.get === offset1) {
+          Thread.sleep(10)
+        }
+        val offset2 = source.getOffset.get
+        val batch2 = source.getBatch(Some(offset1), offset2)
+        val batch2Seq = batch2.as[(String, Timestamp)].collect().toSeq
+        assert(batch2Seq.map(_._1) === Seq("world"))
+        val batch2Stamp = batch2Seq(0)._2
+        assert(!batch2Stamp.before(batch1Stamp))
+
+        // Try stopping the source to make sure this does not block forever.
+        source.stop()
+        source = null
       }
-      val offset2 = source.getOffset.get
-      val batch2 = source.getBatch(Some(offset1), offset2)
-      val batch2Seq = batch2.as[(String, Timestamp)].collect().toSeq
-      assert(batch2Seq.map(_._1) === Seq("world"))
-      val batch2Stamp = batch2Seq(0)._2
-      assert(!batch2Stamp.before(batch1Stamp))
-
-      // Try stopping the source to make sure this does not block forever.
-      source.stop()
-      source = null
     }
   }
 
@@ -164,19 +171,42 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before
     val parameters = Map("host" -> "localhost", "port" -> serverThread.port.toString)
     source = provider.createSource(sqlContext, "", None, "", parameters)
 
-    failAfter(streamingTimeout) {
-      serverThread.enqueue("hello")
-      while (source.getOffset.isEmpty) {
-        Thread.sleep(10)
+    withAdditionalConf(Map(UNSUPPORTED_OPERATION_CHECK_ENABLED.key -> "false")) { () =>
+      failAfter(streamingTimeout) {
+        serverThread.enqueue("hello")
+        while (source.getOffset.isEmpty) {
+          Thread.sleep(10)
+        }
+        val batch = source.getBatch(None, source.getOffset.get).as[String]
+        batch.collect()
+        val numRowsMetric =
+          batch.queryExecution.executedPlan.collectLeaves().head.metrics.get("numOutputRows")
+        assert(numRowsMetric.nonEmpty)
+        assert(numRowsMetric.get.value === 1)
+        source.stop()
+        source = null
      }
-      val batch = source.getBatch(None, source.getOffset.get).as[String]
-      batch.collect()
-      val numRowsMetric =
-        batch.queryExecution.executedPlan.collectLeaves().head.metrics.get("numOutputRows")
-      assert(numRowsMetric.nonEmpty)
-      assert(numRowsMetric.get.value === 1)
-      source.stop()
-      source = null
+    }
+  }
+
+  def withAdditionalConf(additionalConf: Map[String, String] = Map.empty)(f: () => Unit): Unit = {
+    val resetConfValues = mutable.Map[String, Option[String]]()
+    val conf = sqlContext.sparkSession.conf
+    additionalConf.foreach(pair => {
+      val value = if (conf.contains(pair._1)) {
+        Some(conf.get(pair._1))
+      } else {
+        None
+      }
+      resetConfValues(pair._1) = value
+      conf.set(pair._1, pair._2)
+    })
+
+    f()
+
+    resetConfValues.foreach {
+      case (key, Some(value)) => conf.set(key, value)
+      case (key, None) => conf.unset(key)
     }
   }
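
Note: the patch wraps each test that drives the socket source's getBatch directly in the new withAdditionalConf helper, which temporarily applies the given SQLConf entries (here UNSUPPORTED_OPERATION_CHECK_ENABLED set to "false") and puts the previous values back once the block has run. Below is a minimal, self-contained sketch of that save-and-restore pattern; the names TempConfExample and withTempConf and the local SparkSession setup are illustrative only, not part of the patch, and the sketch assumes the SQLConf constant is visible as it is inside Spark's own modules. Unlike the committed helper, it restores the old values in a finally block so the conf is reset even if the body throws.

import scala.collection.mutable

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.internal.SQLConf._

object TempConfExample {

  // Apply `entries` to the session conf, run `body`, then restore the previous
  // values, unsetting any key that was not set before. Restoration happens in
  // `finally`, so a failing body still cleans up after itself.
  def withTempConf(spark: SparkSession, entries: Map[String, String])(body: => Unit): Unit = {
    val conf = spark.conf
    val previous = mutable.Map[String, Option[String]]()
    entries.foreach { case (key, value) =>
      previous(key) = if (conf.contains(key)) Some(conf.get(key)) else None
      conf.set(key, value)
    }
    try {
      body
    } finally {
      previous.foreach {
        case (key, Some(old)) => conf.set(key, old)
        case (key, None) => conf.unset(key)
      }
    }
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("temp-conf-sketch").getOrCreate()
    // Toggle the same key the patched tests disable, only inside this block.
    withTempConf(spark, Map(UNSUPPORTED_OPERATION_CHECK_ENABLED.key -> "false")) {
      assert(spark.conf.get(UNSUPPORTED_OPERATION_CHECK_ENABLED.key) == "false")
    }
    spark.stop()
  }
}

The committed helper takes an explicit f: () => Unit and restores the values after calling f(); the finally in the sketch is a defensive variation, not part of the commit.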
