Skip to content

Commit 34449e4

Browse files
committed
Addressed many comments
1 parent 8be63de commit 34449e4

File tree

13 files changed

+249
-213
lines changed

13 files changed

+249
-213
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,13 @@ object UnsupportedOperationChecker {
4646
"Queries without streaming sources cannot be executed with writeStream.start()")(plan)
4747
}
4848

49+
/** Collect all the streaming aggregates in a sub plan */
50+
def collectStreamingAggregates(subplan: LogicalPlan): Seq[Aggregate] = {
51+
subplan.collect { case a@Aggregate(_, _, _) if a.isStreaming => a }
52+
}
53+
4954
// Disallow multiple streaming aggregations
50-
val aggregates = plan.collect { case a@Aggregate(_, _, _) if a.isStreaming => a }
55+
val aggregates = collectStreamingAggregates(plan)
5156

5257
if (aggregates.size > 1) {
5358
throwError(
@@ -114,6 +119,10 @@ object UnsupportedOperationChecker {
114119
case _: InsertIntoTable =>
115120
throwError("InsertIntoTable is not supported with streaming DataFrames/Datasets")
116121

122+
case m: MapGroupsWithState if collectStreamingAggregates(m).nonEmpty =>
123+
throwError("(map/flatMap)GroupsWithState is not supported after aggregation on a " +
124+
"streaming DataFrame/Dataset")
125+
117126
case Join(left, right, joinType, _) =>
118127

119128
joinType match {

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -314,12 +314,12 @@ case class MapGroups(
314314
child: LogicalPlan) extends UnaryNode with ObjectProducer
315315

316316
/** Internal class representing State */
317-
trait LogicalState[S]
317+
trait LogicalKeyedState[S]
318318

319319
/** Factory for constructing new `MapGroupsWithState` nodes. */
320320
object MapGroupsWithState {
321321
def apply[K: Encoder, V: Encoder, S: Encoder, U: Encoder](
322-
func: (Any, Iterator[Any], LogicalState[Any]) => Iterator[Any],
322+
func: (Any, Iterator[Any], LogicalKeyedState[Any]) => Iterator[Any],
323323
groupingAttributes: Seq[Attribute],
324324
dataAttributes: Seq[Attribute],
325325
child: LogicalPlan): LogicalPlan = {
@@ -352,7 +352,7 @@ object MapGroupsWithState {
352352
* @param stateSerializer used to serialize updated state after calling `func`
353353
*/
354354
case class MapGroupsWithState(
355-
func: (Any, Iterator[Any], LogicalState[Any]) => Iterator[Any],
355+
func: (Any, Iterator[Any], LogicalKeyedState[Any]) => Iterator[Any],
356356
keyDeserializer: Expression,
357357
valueDeserializer: Expression,
358358
groupingAttributes: Seq[Attribute],

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,13 @@ import org.apache.spark.sql.AnalysisException
2222
import org.apache.spark.sql.catalyst.dsl.expressions._
2323
import org.apache.spark.sql.catalyst.dsl.plans._
2424
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
25-
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, NamedExpression}
25+
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Literal, NamedExpression}
2626
import org.apache.spark.sql.catalyst.expressions.aggregate.Count
2727
import org.apache.spark.sql.catalyst.plans._
28-
import org.apache.spark.sql.catalyst.plans.logical._
28+
import org.apache.spark.sql.catalyst.plans.logical.{MapGroupsWithState, _}
2929
import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
3030
import org.apache.spark.sql.streaming.OutputMode
31-
import org.apache.spark.sql.types.IntegerType
31+
import org.apache.spark.sql.types.{IntegerType, LongType}
3232

3333
/** A dummy command for testing unsupported operations. */
3434
case class DummyCommand() extends Command
@@ -111,6 +111,25 @@ class UnsupportedOperationsSuite extends SparkFunSuite {
111111
outputMode = Complete,
112112
expectedMsgs = Seq("distinct aggregation"))
113113

114+
// MapGroupsWithState: Not supported after a streaming aggregation
115+
val att = new AttributeReference(name = "a", dataType = LongType)()
116+
assertSupportedInStreamingPlan(
117+
"mapGroupsWithState - mapGroupsWithState on batch relation",
118+
MapGroupsWithState(null, att, att, Seq(att), Seq(att), att, att, Seq(att), batchRelation),
119+
outputMode = Append)
120+
121+
assertSupportedInStreamingPlan(
122+
"mapGroupsWithState - mapGroupsWithState on streaming relation before aggregation",
123+
MapGroupsWithState(null, att, att, Seq(att), Seq(att), att, att, Seq(att), streamRelation),
124+
outputMode = Append)
125+
126+
assertNotSupportedInStreamingPlan(
127+
"mapGroupsWithState - mapGroupsWithState on streaming relation after aggregation",
128+
MapGroupsWithState(null, att, att, Seq(att), Seq(att), att, att, Seq(att),
129+
Aggregate(Nil, aggExprs("c"), streamRelation)),
130+
outputMode = Complete,
131+
expectedMsgs = Seq("(map/flatMap)GroupsWithState"))
132+
114133
// Inner joins: Stream-stream not supported
115134
testBinaryOperationInStreamingPlan(
116135
"inner join",

sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -243,15 +243,15 @@ class KeyValueGroupedDataset[K, V] private[sql](
243243
* (for example, by calling `toList`) unless they are sure that this is possible given the memory
244244
* constraints of their cluster.
245245
*
246-
* @see [[State]] for more details of how to update/remove state in the function.
246+
* @see [[KeyedState]] for more details of how to update/remove state in the function.
247247
* @since 2.1.1
248248
*/
249249
@Experimental
250250
@InterfaceStability.Evolving
251251
def mapGroupsWithState[STATE: Encoder, OUT: Encoder](
252-
func: (K, Iterator[V], State[STATE]) => OUT): Dataset[OUT] = {
253-
val f = (key: K, it: Iterator[V], s: State[STATE]) => Iterator(func(key, it, s))
254-
flatMapGroupsWithState[STATE, OUT](f)
252+
func: (K, Iterator[V], KeyedState[STATE]) => OUT): Dataset[OUT] = {
253+
flatMapGroupsWithState[STATE, OUT](
254+
(key: K, it: Iterator[V], s: KeyedState[STATE]) => Iterator(func(key, it, s)))
255255
}
256256

257257
/**
@@ -279,7 +279,7 @@ class KeyValueGroupedDataset[K, V] private[sql](
279279
* (for example, by calling `toList`) unless they are sure that this is possible given the memory
280280
* constraints of their cluster.
281281
*
282-
* @see [[State]] for more details of how to update/remove state in the function.
282+
* @see [[KeyedState]] for more details of how to update/remove state in the function.
283283
* @since 2.1.1
284284
*/
285285
@Experimental
@@ -288,8 +288,9 @@ class KeyValueGroupedDataset[K, V] private[sql](
288288
func: MapGroupsWithStateFunction[K, V, STATE, OUT],
289289
stateEncoder: Encoder[STATE],
290290
outputEncoder: Encoder[OUT]): Dataset[OUT] = {
291-
val f = (key: K, it: Iterator[V], s: State[STATE]) => Iterator(func.call(key, it.asJava, s))
292-
flatMapGroupsWithState[STATE, OUT](f)(stateEncoder, outputEncoder)
291+
flatMapGroupsWithState[STATE, OUT](
292+
(key: K, it: Iterator[V], s: KeyedState[STATE]) => Iterator(func.call(key, it.asJava, s))
293+
)(stateEncoder, outputEncoder)
293294
}
294295

295296

@@ -318,17 +319,17 @@ class KeyValueGroupedDataset[K, V] private[sql](
318319
* (for example, by calling `toList`) unless they are sure that this is possible given the memory
319320
* constraints of their cluster.
320321
*
321-
* @see [[State]] for more details of how to update/remove state in the function.
322+
* @see [[KeyedState]] for more details of how to update/remove state in the function.
322323
* @since 2.1.1
323324
*/
324325
@Experimental
325326
@InterfaceStability.Evolving
326327
def flatMapGroupsWithState[STATE: Encoder, OUT: Encoder](
327-
func: (K, Iterator[V], State[STATE]) => Iterator[OUT]): Dataset[OUT] = {
328+
func: (K, Iterator[V], KeyedState[STATE]) => Iterator[OUT]): Dataset[OUT] = {
328329
Dataset[OUT](
329330
sparkSession,
330331
MapGroupsWithState[K, V, STATE, OUT](
331-
func.asInstanceOf[(Any, Iterator[Any], LogicalState[Any]) => Iterator[Any]],
332+
func.asInstanceOf[(Any, Iterator[Any], LogicalKeyedState[Any]) => Iterator[Any]],
332333
groupingAttributes,
333334
dataAttributes,
334335
logicalPlan))
@@ -359,7 +360,7 @@ class KeyValueGroupedDataset[K, V] private[sql](
359360
* (for example, by calling `toList`) unless they are sure that this is possible given the memory
360361
* constraints of their cluster.
361362
*
362-
* @see [[State]] for more details of how to update/remove state in the function.
363+
* @see [[KeyedState]] for more details of how to update/remove state in the function.
363364
* @since 2.1.1
364365
*/
365366
@Experimental
@@ -368,8 +369,9 @@ class KeyValueGroupedDataset[K, V] private[sql](
368369
func: FlatMapGroupsWithStateFunction[K, V, STATE, OUT],
369370
stateEncoder: Encoder[STATE],
370371
outputEncoder: Encoder[OUT]): Dataset[OUT] = {
371-
val f = (key: K, it: Iterator[V], s: State[STATE]) => func.call(key, it.asJava, s).asScala
372-
flatMapGroupsWithState[STATE, OUT](f)(stateEncoder, outputEncoder)
372+
flatMapGroupsWithState[STATE, OUT](
373+
(key: K, it: Iterator[V], s: KeyedState[STATE]) => func.call(key, it.asJava, s).asScala
374+
)(stateEncoder, outputEncoder)
373375
}
374376

375377
/**
Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql
19+
20+
import org.apache.spark.annotation.{Experimental, InterfaceStability}
21+
import org.apache.spark.sql.catalyst.plans.logical.LogicalKeyedState
22+
23+
/**
24+
* :: Experimental ::
25+
*
26+
* Wrapper class for interacting with keyed state data in `mapGroupsWithState` and
27+
* `flatMapGroupsWithState` operations on
28+
* [[org.apache.spark.sql.KeyValueGroupedDataset KeyValueGroupedDataset]].
29+
*
30+
* Important points to note.
31+
* - State can be `null`, so updating the state to `null` is not the same as removing the state.
32+
* - Operations on state are not threadsafe. This is to avoid memory barriers.
33+
* - If `remove()` is called, then `exists()` will return `false`, and
34+
* `getOption()` will return `None`.
35+
* - If `update(newState)` is called after that, then `exists()` will return `true`,
36+
* and `getOption()` will return `Some(...)`.
37+
*
38+
* Scala example of using `KeyedState`:
39+
* {{{
40+
* // A mapping function that maintains an integer state for string keys and returns a string.
41+
* def mappingFunction(key: String, value: Iterator[Int], state: KeyedState[Int]): String = {
42+
* // Check if state exists
43+
* if (state.exists) {
44+
* val existingState = state.get // Get the existing state
45+
* val shouldRemove = ... // Decide whether to remove the state
46+
* if (shouldRemove) {
47+
* state.remove() // Remove the state
48+
* } else {
49+
* val newState = ...
50+
* state.update(newState) // Set the new state
51+
* }
52+
* } else {
53+
* val initialState = ...
54+
* state.update(initialState) // Set the initial state
55+
* }
56+
* ... // return something
57+
* }
58+
*
59+
* }}}
60+
*
61+
* Java example of using `KeyedState`:
62+
* {{{
63+
* // A mapping function that maintains an integer state for string keys and returns a string.
64+
* MapGroupsWithStateFunction<String, Integer, Integer, String> mappingFunction =
65+
* new MapGroupsWithStateFunction<String, Integer, Integer, String>() {
66+
*
67+
* @Override
68+
* public String call(String key, Iterator<Integer> value, KeyedState<Integer> state) {
69+
* if (state.exists()) {
70+
* int existingState = state.get(); // Get the existing state
71+
* boolean shouldRemove = ...; // Decide whether to remove the state
72+
* if (shouldRemove) {
73+
* state.remove(); // Remove the state
74+
* } else {
75+
* int newState = ...;
76+
* state.update(newState); // Set the new state
77+
* }
78+
* } else {
79+
* int initialState = ...; // Set the initial state
80+
* state.update(initialState);
81+
* }
82+
* ... // return something
83+
* }
84+
* };
85+
* }}}
86+
*
87+
* @tparam S User-defined type of the state to be stored for each key. Must be encodable into
88+
* Spark SQL types (see [[Encoder]] for more details).
89+
* @since 2.1.1
90+
*/
91+
@Experimental
92+
@InterfaceStability.Evolving
93+
trait KeyedState[S] extends LogicalKeyedState[S] {
94+
95+
/** Whether state exists or not. */
96+
def exists: Boolean
97+
98+
/** Get the state object if it is defined, otherwise throws NoSuchElementException. */
99+
def get: S
100+
101+
/**
102+
* Update the value of the state. Note that null is a valid value, and does not signify removal
103+
* of the state.
104+
*/
105+
def update(newState: S): Unit
106+
107+
/** Remove this keyed state. */
108+
def remove(): Unit
109+
110+
/** (Scala-friendly) Get the state object as an [[Option]]. */
111+
@inline final def getOption: Option[S] = if (exists) Some(get) else None
112+
113+
@inline final override def toString: String = {
114+
getOption.map { _.toString }.getOrElse("<undefined>")
115+
}
116+
}

sql/core/src/main/scala/org/apache/spark/sql/State.scala

Lines changed: 0 additions & 101 deletions
This file was deleted.

0 commit comments

Comments
 (0)