@@ -22,10 +22,13 @@ import java.net.ServerSocket
2222import java .sql .Timestamp
2323import java .util .concurrent .LinkedBlockingQueue
2424
25+ import scala .collection .mutable
26+
2527import org .scalatest .BeforeAndAfterEach
2628
2729import org .apache .spark .internal .Logging
2830import org .apache .spark .sql .AnalysisException
31+ import org .apache .spark .sql .internal .SQLConf ._
2932import org .apache .spark .sql .streaming .StreamTest
3033import org .apache .spark .sql .test .SharedSQLContext
3134import org .apache .spark .sql .types .{StringType , StructField , StructType , TimestampType }
@@ -60,29 +63,31 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before
6063
6164 source = provider.createSource(sqlContext, " " , None , " " , parameters)
6265
63- failAfter(streamingTimeout) {
64- serverThread.enqueue(" hello" )
65- while (source.getOffset.isEmpty) {
66- Thread .sleep(10 )
67- }
68- val offset1 = source.getOffset.get
69- val batch1 = source.getBatch(None , offset1)
70- assert(batch1.as[String ].collect().toSeq === Seq (" hello" ))
66+ withAdditionalConf(Map (UNSUPPORTED_OPERATION_CHECK_ENABLED .key -> " false" )) {() =>
67+ failAfter(streamingTimeout) {
68+ serverThread.enqueue(" hello" )
69+ while (source.getOffset.isEmpty) {
70+ Thread .sleep(10 )
71+ }
72+ val offset1 = source.getOffset.get
73+ val batch1 = source.getBatch(None , offset1)
74+ assert(batch1.as[String ].collect().toSeq === Seq (" hello" ))
7175
72- serverThread.enqueue(" world" )
73- while (source.getOffset.get === offset1) {
74- Thread .sleep(10 )
75- }
76- val offset2 = source.getOffset.get
77- val batch2 = source.getBatch(Some (offset1), offset2)
78- assert(batch2.as[String ].collect().toSeq === Seq (" world" ))
76+ serverThread.enqueue(" world" )
77+ while (source.getOffset.get === offset1) {
78+ Thread .sleep(10 )
79+ }
80+ val offset2 = source.getOffset.get
81+ val batch2 = source.getBatch(Some (offset1), offset2)
82+ assert(batch2.as[String ].collect().toSeq === Seq (" world" ))
7983
80- val both = source.getBatch(None , offset2)
81- assert(both.as[String ].collect().sorted.toSeq === Seq (" hello" , " world" ))
84+ val both = source.getBatch(None , offset2)
85+ assert(both.as[String ].collect().sorted.toSeq === Seq (" hello" , " world" ))
8286
83- // Try stopping the source to make sure this does not block forever.
84- source.stop()
85- source = null
87+ // Try stopping the source to make sure this does not block forever.
88+ source.stop()
89+ source = null
90+ }
8691 }
8792 }
8893
@@ -99,31 +104,33 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before
99104
100105 source = provider.createSource(sqlContext, " " , None , " " , parameters)
101106
102- failAfter(streamingTimeout) {
103- serverThread.enqueue(" hello" )
104- while (source.getOffset.isEmpty) {
105- Thread .sleep(10 )
106- }
107- val offset1 = source.getOffset.get
108- val batch1 = source.getBatch(None , offset1)
109- val batch1Seq = batch1.as[(String , Timestamp )].collect().toSeq
110- assert(batch1Seq.map(_._1) === Seq (" hello" ))
111- val batch1Stamp = batch1Seq(0 )._2
112-
113- serverThread.enqueue(" world" )
114- while (source.getOffset.get === offset1) {
115- Thread .sleep(10 )
107+ withAdditionalConf(Map (UNSUPPORTED_OPERATION_CHECK_ENABLED .key -> " false" )) { () =>
108+ failAfter(streamingTimeout) {
109+ serverThread.enqueue(" hello" )
110+ while (source.getOffset.isEmpty) {
111+ Thread .sleep(10 )
112+ }
113+ val offset1 = source.getOffset.get
114+ val batch1 = source.getBatch(None , offset1)
115+ val batch1Seq = batch1.as[(String , Timestamp )].collect().toSeq
116+ assert(batch1Seq.map(_._1) === Seq (" hello" ))
117+ val batch1Stamp = batch1Seq(0 )._2
118+
119+ serverThread.enqueue(" world" )
120+ while (source.getOffset.get === offset1) {
121+ Thread .sleep(10 )
122+ }
123+ val offset2 = source.getOffset.get
124+ val batch2 = source.getBatch(Some (offset1), offset2)
125+ val batch2Seq = batch2.as[(String , Timestamp )].collect().toSeq
126+ assert(batch2Seq.map(_._1) === Seq (" world" ))
127+ val batch2Stamp = batch2Seq(0 )._2
128+ assert(! batch2Stamp.before(batch1Stamp))
129+
130+ // Try stopping the source to make sure this does not block forever.
131+ source.stop()
132+ source = null
116133 }
117- val offset2 = source.getOffset.get
118- val batch2 = source.getBatch(Some (offset1), offset2)
119- val batch2Seq = batch2.as[(String , Timestamp )].collect().toSeq
120- assert(batch2Seq.map(_._1) === Seq (" world" ))
121- val batch2Stamp = batch2Seq(0 )._2
122- assert(! batch2Stamp.before(batch1Stamp))
123-
124- // Try stopping the source to make sure this does not block forever.
125- source.stop()
126- source = null
127134 }
128135 }
129136
@@ -164,19 +171,42 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before
164171 val parameters = Map (" host" -> " localhost" , " port" -> serverThread.port.toString)
165172 source = provider.createSource(sqlContext, " " , None , " " , parameters)
166173
167- failAfter(streamingTimeout) {
168- serverThread.enqueue(" hello" )
169- while (source.getOffset.isEmpty) {
170- Thread .sleep(10 )
174+ withAdditionalConf(Map (UNSUPPORTED_OPERATION_CHECK_ENABLED .key -> " false" )) { () =>
175+ failAfter(streamingTimeout) {
176+ serverThread.enqueue(" hello" )
177+ while (source.getOffset.isEmpty) {
178+ Thread .sleep(10 )
179+ }
180+ val batch = source.getBatch(None , source.getOffset.get).as[String ]
181+ batch.collect()
182+ val numRowsMetric =
183+ batch.queryExecution.executedPlan.collectLeaves().head.metrics.get(" numOutputRows" )
184+ assert(numRowsMetric.nonEmpty)
185+ assert(numRowsMetric.get.value === 1 )
186+ source.stop()
187+ source = null
171188 }
172- val batch = source.getBatch(None , source.getOffset.get).as[String ]
173- batch.collect()
174- val numRowsMetric =
175- batch.queryExecution.executedPlan.collectLeaves().head.metrics.get(" numOutputRows" )
176- assert(numRowsMetric.nonEmpty)
177- assert(numRowsMetric.get.value === 1 )
178- source.stop()
179- source = null
189+ }
190+ }
191+
192+ def withAdditionalConf (additionalConf : Map [String , String ] = Map .empty)(f : () => Unit ): Unit = {
193+ val resetConfValues = mutable.Map [String , Option [String ]]()
194+ val conf = sqlContext.sparkSession.conf
195+ additionalConf.foreach(pair => {
196+ val value = if (conf.contains(pair._1)) {
197+ Some (conf.get(pair._1))
198+ } else {
199+ None
200+ }
201+ resetConfValues(pair._1) = value
202+ conf.set(pair._1, pair._2)
203+ })
204+
205+ f()
206+
207+ resetConfValues.foreach {
208+ case (key, Some (value)) => conf.set(key, value)
209+ case (key, None ) => conf.unset(key)
180210 }
181211 }
182212
0 commit comments