diff --git a/examples/src/main/java/org/apache/spark/examples/sql/streaming/JavaStructuredSessionization.java b/examples/src/main/java/org/apache/spark/examples/sql/streaming/JavaStructuredSessionization.java
index 34ee235d8b..eb7ce11c26 100644
--- a/examples/src/main/java/org/apache/spark/examples/sql/streaming/JavaStructuredSessionization.java
+++ b/examples/src/main/java/org/apache/spark/examples/sql/streaming/JavaStructuredSessionization.java
@@ -16,22 +16,15 @@
*/
package org.apache.spark.examples.sql.streaming;
-import org.apache.spark.api.java.function.FlatMapFunction;
-import org.apache.spark.api.java.function.MapFunction;
-import org.apache.spark.api.java.function.MapGroupsWithStateFunction;
import org.apache.spark.sql.*;
-import org.apache.spark.sql.streaming.GroupState;
-import org.apache.spark.sql.streaming.GroupStateTimeout;
import org.apache.spark.sql.streaming.StreamingQuery;
-import java.io.Serializable;
-import java.sql.Timestamp;
-import java.util.*;
+import static org.apache.spark.sql.functions.*;
/**
* Counts words in UTF8 encoded, '\n' delimited text received from the network.
*
- * Usage: JavaStructuredNetworkWordCount
+ * Usage: JavaStructuredSessionization
* and describe the TCP server that Structured Streaming
* would connect to receive data.
*
@@ -66,86 +59,20 @@ public final class JavaStructuredSessionization {
.option("includeTimestamp", true)
.load();
- FlatMapFunction linesToEvents =
- new FlatMapFunction() {
- @Override
- public Iterator call(LineWithTimestamp lineWithTimestamp) {
- ArrayList eventList = new ArrayList<>();
- for (String word : lineWithTimestamp.getLine().split(" ")) {
- eventList.add(new Event(word, lineWithTimestamp.getTimestamp()));
- }
- return eventList.iterator();
- }
- };
+ // Split the lines into words, retaining timestamps
+ // split() splits each line into an array, and explode() turns the array into multiple rows
+ // treat words as sessionId of events
+ Dataset events = lines
+ .selectExpr("explode(split(value, ' ')) AS sessionId", "timestamp AS eventTime");
- // Split the lines into words, treat words as sessionId of events
- Dataset events = lines
- .withColumnRenamed("value", "line")
- .as(Encoders.bean(LineWithTimestamp.class))
- .flatMap(linesToEvents, Encoders.bean(Event.class));
-
- // Sessionize the events. Track number of events, start and end timestamps of session, and
+ // Sessionize the events. Track number of events, start and end timestamps of session,
// and report session updates.
- //
- // Step 1: Define the state update function
- MapGroupsWithStateFunction stateUpdateFunc =
- new MapGroupsWithStateFunction() {
- @Override public SessionUpdate call(
- String sessionId, Iterator events, GroupState state) {
- // If timed out, then remove session and send final update
- if (state.hasTimedOut()) {
- SessionUpdate finalUpdate = new SessionUpdate(
- sessionId, state.get().calculateDuration(), state.get().getNumEvents(), true);
- state.remove();
- return finalUpdate;
-
- } else {
- // Find max and min timestamps in events
- long maxTimestampMs = Long.MIN_VALUE;
- long minTimestampMs = Long.MAX_VALUE;
- int numNewEvents = 0;
- while (events.hasNext()) {
- Event e = events.next();
- long timestampMs = e.getTimestamp().getTime();
- maxTimestampMs = Math.max(timestampMs, maxTimestampMs);
- minTimestampMs = Math.min(timestampMs, minTimestampMs);
- numNewEvents += 1;
- }
- SessionInfo updatedSession = new SessionInfo();
-
- // Update start and end timestamps in session
- if (state.exists()) {
- SessionInfo oldSession = state.get();
- updatedSession.setNumEvents(oldSession.numEvents + numNewEvents);
- updatedSession.setStartTimestampMs(oldSession.startTimestampMs);
- updatedSession.setEndTimestampMs(Math.max(oldSession.endTimestampMs, maxTimestampMs));
- } else {
- updatedSession.setNumEvents(numNewEvents);
- updatedSession.setStartTimestampMs(minTimestampMs);
- updatedSession.setEndTimestampMs(maxTimestampMs);
- }
- state.update(updatedSession);
- // Set timeout such that the session will be expired if no data received for 10 seconds
- state.setTimeoutDuration("10 seconds");
- return new SessionUpdate(
- sessionId, state.get().calculateDuration(), state.get().getNumEvents(), false);
- }
- }
- };
-
- // Step 2: Apply the state update function to the events streaming Dataset grouped by sessionId
- Dataset sessionUpdates = events
- .groupByKey(
- new MapFunction() {
- @Override public String call(Event event) {
- return event.getSessionId();
- }
- }, Encoders.STRING())
- .mapGroupsWithState(
- stateUpdateFunc,
- Encoders.bean(SessionInfo.class),
- Encoders.bean(SessionUpdate.class),
- GroupStateTimeout.ProcessingTimeTimeout());
+ Dataset sessionUpdates = events
+ .groupBy(session_window(col("eventTime"), "10 seconds").as("session"), col("sessionId"))
+ .agg(count("*").as("numEvents"))
+ .selectExpr("sessionId", "CAST(session.start AS LONG)", "CAST(session.end AS LONG)",
+ "CAST(session.end AS LONG) - CAST(session.start AS LONG) AS durationMs",
+ "numEvents");
// Start running the query that prints the session updates to the console
StreamingQuery query = sessionUpdates
@@ -156,96 +83,4 @@ public final class JavaStructuredSessionization {
query.awaitTermination();
}
-
- /**
- * User-defined data type representing the raw lines with timestamps.
- */
- public static class LineWithTimestamp implements Serializable {
- private String line;
- private Timestamp timestamp;
-
- public Timestamp getTimestamp() { return timestamp; }
- public void setTimestamp(Timestamp timestamp) { this.timestamp = timestamp; }
-
- public String getLine() { return line; }
- public void setLine(String sessionId) { this.line = sessionId; }
- }
-
- /**
- * User-defined data type representing the input events
- */
- public static class Event implements Serializable {
- private String sessionId;
- private Timestamp timestamp;
-
- public Event() { }
- public Event(String sessionId, Timestamp timestamp) {
- this.sessionId = sessionId;
- this.timestamp = timestamp;
- }
-
- public Timestamp getTimestamp() { return timestamp; }
- public void setTimestamp(Timestamp timestamp) { this.timestamp = timestamp; }
-
- public String getSessionId() { return sessionId; }
- public void setSessionId(String sessionId) { this.sessionId = sessionId; }
- }
-
- /**
- * User-defined data type for storing a session information as state in mapGroupsWithState.
- */
- public static class SessionInfo implements Serializable {
- private int numEvents = 0;
- private long startTimestampMs = -1;
- private long endTimestampMs = -1;
-
- public int getNumEvents() { return numEvents; }
- public void setNumEvents(int numEvents) { this.numEvents = numEvents; }
-
- public long getStartTimestampMs() { return startTimestampMs; }
- public void setStartTimestampMs(long startTimestampMs) {
- this.startTimestampMs = startTimestampMs;
- }
-
- public long getEndTimestampMs() { return endTimestampMs; }
- public void setEndTimestampMs(long endTimestampMs) { this.endTimestampMs = endTimestampMs; }
-
- public long calculateDuration() { return endTimestampMs - startTimestampMs; }
-
- @Override public String toString() {
- return "SessionInfo(numEvents = " + numEvents +
- ", timestamps = " + startTimestampMs + " to " + endTimestampMs + ")";
- }
- }
-
- /**
- * User-defined data type representing the update information returned by mapGroupsWithState.
- */
- public static class SessionUpdate implements Serializable {
- private String id;
- private long durationMs;
- private int numEvents;
- private boolean expired;
-
- public SessionUpdate() { }
-
- public SessionUpdate(String id, long durationMs, int numEvents, boolean expired) {
- this.id = id;
- this.durationMs = durationMs;
- this.numEvents = numEvents;
- this.expired = expired;
- }
-
- public String getId() { return id; }
- public void setId(String id) { this.id = id; }
-
- public long getDurationMs() { return durationMs; }
- public void setDurationMs(long durationMs) { this.durationMs = durationMs; }
-
- public int getNumEvents() { return numEvents; }
- public void setNumEvents(int numEvents) { this.numEvents = numEvents; }
-
- public boolean isExpired() { return expired; }
- public void setExpired(boolean expired) { this.expired = expired; }
- }
}
diff --git a/examples/src/main/python/sql/streaming/structured_sessionization.py b/examples/src/main/python/sql/streaming/structured_sessionization.py
new file mode 100644
index 0000000000..78cb406650
--- /dev/null
+++ b/examples/src/main/python/sql/streaming/structured_sessionization.py
@@ -0,0 +1,87 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+r"""
+ Counts words in UTF8 encoded, '\n' delimited text received from the network over a
+ sliding window of configurable duration.
+
+ Usage: structured_sessionization.py
+ and describe the TCP server that Structured Streaming
+ would connect to receive data.
+
+ To run this on your local machine, you need to first run a Netcat server
+ `$ nc -lk 9999`
+ and then run the example
+ `$ bin/spark-submit
+ examples/src/main/python/sql/streaming/structured_sessionization.py
+ localhost 9999`
+"""
+import sys
+
+from pyspark.sql import SparkSession
+from pyspark.sql.functions import explode
+from pyspark.sql.functions import split
+from pyspark.sql.functions import count, session_window
+
+if __name__ == "__main__":
+ if len(sys.argv) != 3 and len(sys.argv) != 2:
+ msg = "Usage: structured_sessionization.py "
+ print(msg, file=sys.stderr)
+ sys.exit(-1)
+
+ host = sys.argv[1]
+ port = int(sys.argv[2])
+
+ spark = SparkSession\
+ .builder\
+ .appName("StructuredSessionization")\
+ .getOrCreate()
+
+ # Create DataFrame representing the stream of input lines from connection to host:port
+ lines = spark\
+ .readStream\
+ .format('socket')\
+ .option('host', host)\
+ .option('port', port)\
+ .option('includeTimestamp', 'true')\
+ .load()
+
+ # Split the lines into words, retaining timestamps
+ # split() splits each line into an array, and explode() turns the array into multiple rows
+ # treat words as sessionId of events
+ events = lines.select(
+ explode(split(lines.value, ' ')).alias('sessionId'),
+ lines.timestamp.alias('eventTime')
+ )
+
+ # Group the data by window and word and compute the count of each group
+ windowedCounts = events \
+ .groupBy(session_window(events.eventTime, "10 seconds").alias('session'),
+ events.sessionId) \
+ .agg(count("*").alias("numEvents")) \
+ .selectExpr("sessionId", "CAST(session.start AS LONG)", "CAST(session.end AS LONG)",
+ "CAST(session.end AS LONG) - CAST(session.start AS LONG) AS durationMs",
+ "numEvents")
+
+ # Start running the query that prints the session updates to the console
+ query = windowedCounts\
+ .writeStream\
+ .outputMode('update')\
+ .format('console')\
+ .start()
+
+ query.awaitTermination()
diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/streaming/StructuredSessionization.scala b/examples/src/main/scala/org/apache/spark/examples/sql/streaming/StructuredSessionization.scala
index 29dbb0d95c..63e8dd9c7b 100644
--- a/examples/src/main/scala/org/apache/spark/examples/sql/streaming/StructuredSessionization.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/sql/streaming/StructuredSessionization.scala
@@ -18,16 +18,14 @@
// scalastyle:off println
package org.apache.spark.examples.sql.streaming
-import java.sql.Timestamp
-
import org.apache.spark.sql.SparkSession
-import org.apache.spark.sql.streaming._
+import org.apache.spark.sql.functions.{count, session_window}
/**
* Counts words in UTF8 encoded, '\n' delimited text received from the network.
*
- * Usage: MapGroupsWithState
+ * Usage: StructuredSessionization
* and describe the TCP server that Structured Streaming
* would connect to receive data.
*
@@ -63,46 +61,20 @@ object StructuredSessionization {
.option("includeTimestamp", true)
.load()
- // Split the lines into words, treat words as sessionId of events
+ // Split the lines into words, retaining timestamps
+ // split() splits each line into an array, and explode() turns the array into multiple rows
+ // treat words as sessionId of events
val events = lines
- .as[(String, Timestamp)]
- .flatMap { case (line, timestamp) =>
- line.split(" ").map(word => Event(sessionId = word, timestamp))
- }
+ .selectExpr("explode(split(value, ' ')) AS sessionId", "timestamp AS eventTime")
// Sessionize the events. Track number of events, start and end timestamps of session,
// and report session updates.
val sessionUpdates = events
- .groupByKey(event => event.sessionId)
- .mapGroupsWithState[SessionInfo, SessionUpdate](GroupStateTimeout.ProcessingTimeTimeout) {
-
- case (sessionId: String, events: Iterator[Event], state: GroupState[SessionInfo]) =>
-
- // If timed out, then remove session and send final update
- if (state.hasTimedOut) {
- val finalUpdate =
- SessionUpdate(sessionId, state.get.durationMs, state.get.numEvents, expired = true)
- state.remove()
- finalUpdate
- } else {
- // Update start and end timestamps in session
- val timestamps = events.map(_.timestamp.getTime).toSeq
- val updatedSession = if (state.exists) {
- val oldSession = state.get
- SessionInfo(
- oldSession.numEvents + timestamps.size,
- oldSession.startTimestampMs,
- math.max(oldSession.endTimestampMs, timestamps.max))
- } else {
- SessionInfo(timestamps.size, timestamps.min, timestamps.max)
- }
- state.update(updatedSession)
-
- // Set timeout such that the session will be expired if no data received for 10 seconds
- state.setTimeoutDuration("10 seconds")
- SessionUpdate(sessionId, state.get.durationMs, state.get.numEvents, expired = false)
- }
- }
+ .groupBy(session_window($"eventTime", "10 seconds") as 'session, 'sessionId)
+ .agg(count("*").as("numEvents"))
+ .selectExpr("sessionId", "CAST(session.start AS LONG)", "CAST(session.end AS LONG)",
+ "CAST(session.end AS LONG) - CAST(session.start AS LONG) AS durationMs",
+ "numEvents")
// Start running the query that prints the session updates to the console
val query = sessionUpdates
@@ -114,38 +86,4 @@ object StructuredSessionization {
query.awaitTermination()
}
}
-/** User-defined data type representing the input events */
-case class Event(sessionId: String, timestamp: Timestamp)
-
-/**
- * User-defined data type for storing a session information as state in mapGroupsWithState.
- *
- * @param numEvents total number of events received in the session
- * @param startTimestampMs timestamp of first event received in the session when it started
- * @param endTimestampMs timestamp of last event received in the session before it expired
- */
-case class SessionInfo(
- numEvents: Int,
- startTimestampMs: Long,
- endTimestampMs: Long) {
-
- /** Duration of the session, between the first and last events */
- def durationMs: Long = endTimestampMs - startTimestampMs
-}
-
-/**
- * User-defined data type representing the update information returned by mapGroupsWithState.
- *
- * @param id Id of the session
- * @param durationMs Duration the session was active, that is, from first event to its expiry
- * @param numEvents Number of events received by the session while it was active
- * @param expired Is the session active or expired
- */
-case class SessionUpdate(
- id: String,
- durationMs: Long,
- numEvents: Int,
- expired: Boolean)
-
// scalastyle:on println
-