diff --git a/examples/src/main/java/org/apache/spark/examples/sql/streaming/JavaStructuredSessionization.java b/examples/src/main/java/org/apache/spark/examples/sql/streaming/JavaStructuredSessionization.java index 34ee235d8b..eb7ce11c26 100644 --- a/examples/src/main/java/org/apache/spark/examples/sql/streaming/JavaStructuredSessionization.java +++ b/examples/src/main/java/org/apache/spark/examples/sql/streaming/JavaStructuredSessionization.java @@ -16,22 +16,15 @@ */ package org.apache.spark.examples.sql.streaming; -import org.apache.spark.api.java.function.FlatMapFunction; -import org.apache.spark.api.java.function.MapFunction; -import org.apache.spark.api.java.function.MapGroupsWithStateFunction; import org.apache.spark.sql.*; -import org.apache.spark.sql.streaming.GroupState; -import org.apache.spark.sql.streaming.GroupStateTimeout; import org.apache.spark.sql.streaming.StreamingQuery; -import java.io.Serializable; -import java.sql.Timestamp; -import java.util.*; +import static org.apache.spark.sql.functions.*; /** * Counts words in UTF8 encoded, '\n' delimited text received from the network. *

 - * Usage: JavaStructuredNetworkWordCount <hostname> <port> + * Usage: JavaStructuredSessionization <hostname> <port> * <hostname> and <port> describe the TCP server that Structured Streaming * would connect to receive data. *

@@ -66,86 +59,20 @@ public final class JavaStructuredSessionization { .option("includeTimestamp", true) .load(); - FlatMapFunction linesToEvents = - new FlatMapFunction() { - @Override - public Iterator call(LineWithTimestamp lineWithTimestamp) { - ArrayList eventList = new ArrayList<>(); - for (String word : lineWithTimestamp.getLine().split(" ")) { - eventList.add(new Event(word, lineWithTimestamp.getTimestamp())); - } - return eventList.iterator(); - } - }; + // Split the lines into words, retaining timestamps + // split() splits each line into an array, and explode() turns the array into multiple rows + // treat words as sessionId of events + Dataset events = lines + .selectExpr("explode(split(value, ' ')) AS sessionId", "timestamp AS eventTime"); - // Split the lines into words, treat words as sessionId of events - Dataset events = lines - .withColumnRenamed("value", "line") - .as(Encoders.bean(LineWithTimestamp.class)) - .flatMap(linesToEvents, Encoders.bean(Event.class)); - - // Sessionize the events. Track number of events, start and end timestamps of session, and + // Sessionize the events. Track number of events, start and end timestamps of session, // and report session updates. 
- // - // Step 1: Define the state update function - MapGroupsWithStateFunction stateUpdateFunc = - new MapGroupsWithStateFunction() { - @Override public SessionUpdate call( - String sessionId, Iterator events, GroupState state) { - // If timed out, then remove session and send final update - if (state.hasTimedOut()) { - SessionUpdate finalUpdate = new SessionUpdate( - sessionId, state.get().calculateDuration(), state.get().getNumEvents(), true); - state.remove(); - return finalUpdate; - - } else { - // Find max and min timestamps in events - long maxTimestampMs = Long.MIN_VALUE; - long minTimestampMs = Long.MAX_VALUE; - int numNewEvents = 0; - while (events.hasNext()) { - Event e = events.next(); - long timestampMs = e.getTimestamp().getTime(); - maxTimestampMs = Math.max(timestampMs, maxTimestampMs); - minTimestampMs = Math.min(timestampMs, minTimestampMs); - numNewEvents += 1; - } - SessionInfo updatedSession = new SessionInfo(); - - // Update start and end timestamps in session - if (state.exists()) { - SessionInfo oldSession = state.get(); - updatedSession.setNumEvents(oldSession.numEvents + numNewEvents); - updatedSession.setStartTimestampMs(oldSession.startTimestampMs); - updatedSession.setEndTimestampMs(Math.max(oldSession.endTimestampMs, maxTimestampMs)); - } else { - updatedSession.setNumEvents(numNewEvents); - updatedSession.setStartTimestampMs(minTimestampMs); - updatedSession.setEndTimestampMs(maxTimestampMs); - } - state.update(updatedSession); - // Set timeout such that the session will be expired if no data received for 10 seconds - state.setTimeoutDuration("10 seconds"); - return new SessionUpdate( - sessionId, state.get().calculateDuration(), state.get().getNumEvents(), false); - } - } - }; - - // Step 2: Apply the state update function to the events streaming Dataset grouped by sessionId - Dataset sessionUpdates = events - .groupByKey( - new MapFunction() { - @Override public String call(Event event) { - return event.getSessionId(); - } - }, 
Encoders.STRING()) - .mapGroupsWithState( - stateUpdateFunc, - Encoders.bean(SessionInfo.class), - Encoders.bean(SessionUpdate.class), - GroupStateTimeout.ProcessingTimeTimeout()); + Dataset sessionUpdates = events + .groupBy(session_window(col("eventTime"), "10 seconds").as("session"), col("sessionId")) + .agg(count("*").as("numEvents")) + .selectExpr("sessionId", "CAST(session.start AS LONG)", "CAST(session.end AS LONG)", + "CAST(session.end AS LONG) - CAST(session.start AS LONG) AS durationMs", + "numEvents"); // Start running the query that prints the session updates to the console StreamingQuery query = sessionUpdates @@ -156,96 +83,4 @@ public final class JavaStructuredSessionization { query.awaitTermination(); } - - /** - * User-defined data type representing the raw lines with timestamps. - */ - public static class LineWithTimestamp implements Serializable { - private String line; - private Timestamp timestamp; - - public Timestamp getTimestamp() { return timestamp; } - public void setTimestamp(Timestamp timestamp) { this.timestamp = timestamp; } - - public String getLine() { return line; } - public void setLine(String sessionId) { this.line = sessionId; } - } - - /** - * User-defined data type representing the input events - */ - public static class Event implements Serializable { - private String sessionId; - private Timestamp timestamp; - - public Event() { } - public Event(String sessionId, Timestamp timestamp) { - this.sessionId = sessionId; - this.timestamp = timestamp; - } - - public Timestamp getTimestamp() { return timestamp; } - public void setTimestamp(Timestamp timestamp) { this.timestamp = timestamp; } - - public String getSessionId() { return sessionId; } - public void setSessionId(String sessionId) { this.sessionId = sessionId; } - } - - /** - * User-defined data type for storing a session information as state in mapGroupsWithState. 
- */ - public static class SessionInfo implements Serializable { - private int numEvents = 0; - private long startTimestampMs = -1; - private long endTimestampMs = -1; - - public int getNumEvents() { return numEvents; } - public void setNumEvents(int numEvents) { this.numEvents = numEvents; } - - public long getStartTimestampMs() { return startTimestampMs; } - public void setStartTimestampMs(long startTimestampMs) { - this.startTimestampMs = startTimestampMs; - } - - public long getEndTimestampMs() { return endTimestampMs; } - public void setEndTimestampMs(long endTimestampMs) { this.endTimestampMs = endTimestampMs; } - - public long calculateDuration() { return endTimestampMs - startTimestampMs; } - - @Override public String toString() { - return "SessionInfo(numEvents = " + numEvents + - ", timestamps = " + startTimestampMs + " to " + endTimestampMs + ")"; - } - } - - /** - * User-defined data type representing the update information returned by mapGroupsWithState. - */ - public static class SessionUpdate implements Serializable { - private String id; - private long durationMs; - private int numEvents; - private boolean expired; - - public SessionUpdate() { } - - public SessionUpdate(String id, long durationMs, int numEvents, boolean expired) { - this.id = id; - this.durationMs = durationMs; - this.numEvents = numEvents; - this.expired = expired; - } - - public String getId() { return id; } - public void setId(String id) { this.id = id; } - - public long getDurationMs() { return durationMs; } - public void setDurationMs(long durationMs) { this.durationMs = durationMs; } - - public int getNumEvents() { return numEvents; } - public void setNumEvents(int numEvents) { this.numEvents = numEvents; } - - public boolean isExpired() { return expired; } - public void setExpired(boolean expired) { this.expired = expired; } - } } diff --git a/examples/src/main/python/sql/streaming/structured_sessionization.py 
b/examples/src/main/python/sql/streaming/structured_sessionization.py
new file mode 100644
index 0000000000..78cb406650
--- /dev/null
+++ b/examples/src/main/python/sql/streaming/structured_sessionization.py
@@ -0,0 +1,87 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+r"""
+ Sessionizes words in UTF8 encoded, '\n' delimited text received from the network: each word
+ is treated as a session id, and a session ends after a 10 second gap between its events.
+
+ Usage: structured_sessionization.py <hostname> <port>
+ <hostname> and <port> describe the TCP server that Structured Streaming
+ would connect to receive data.
+
+ To run this on your local machine, you need to first run a Netcat server
+    `$ nc -lk 9999`
+ and then run the example
+    `$ bin/spark-submit
+    examples/src/main/python/sql/streaming/structured_sessionization.py
+    localhost 9999`
+"""
+import sys
+
+from pyspark.sql import SparkSession
+from pyspark.sql.functions import explode
+from pyspark.sql.functions import split
+from pyspark.sql.functions import count, session_window
+
+if __name__ == "__main__":
+    if len(sys.argv) != 3:
+        msg = "Usage: structured_sessionization.py <hostname> <port>"
+        print(msg, file=sys.stderr)
+        sys.exit(-1)
+
+    host = sys.argv[1]
+    port = int(sys.argv[2])
+
+    spark = SparkSession\
+        .builder\
+        .appName("StructuredSessionization")\
+        .getOrCreate()
+
+    # Create DataFrame representing the stream of input lines from connection to host:port
+    lines = spark\
+        .readStream\
+        .format('socket')\
+        .option('host', host)\
+        .option('port', port)\
+        .option('includeTimestamp', 'true')\
+        .load()
+
+    # Split the lines into words, retaining timestamps
+    # split() splits each line into an array, and explode() turns the array into multiple rows
+    # treat words as sessionId of events
+    events = lines.select(
+        explode(split(lines.value, ' ')).alias('sessionId'),
+        lines.timestamp.alias('eventTime')
+    )
+
+    # Sessionize events per sessionId (10 second gap); unix_millis() yields durationMs in ms
+    windowedCounts = events \
+        .groupBy(session_window(events.eventTime, "10 seconds").alias('session'),
+                 events.sessionId) \
+        .agg(count("*").alias("numEvents")) \
+        .selectExpr("sessionId", "CAST(session.start AS LONG)", "CAST(session.end AS LONG)",
+                    "unix_millis(session.end) - unix_millis(session.start) AS durationMs",
+                    "numEvents")
+
+    # Start running the query that prints the session updates to the console
+    query = windowedCounts\
+        .writeStream\
+        .outputMode('update')\
+        .format('console')\
+        .start()
+
+    query.awaitTermination()
diff --git
a/examples/src/main/scala/org/apache/spark/examples/sql/streaming/StructuredSessionization.scala b/examples/src/main/scala/org/apache/spark/examples/sql/streaming/StructuredSessionization.scala index 29dbb0d95c..63e8dd9c7b 100644 --- a/examples/src/main/scala/org/apache/spark/examples/sql/streaming/StructuredSessionization.scala +++ b/examples/src/main/scala/org/apache/spark/examples/sql/streaming/StructuredSessionization.scala @@ -18,16 +18,14 @@ // scalastyle:off println package org.apache.spark.examples.sql.streaming -import java.sql.Timestamp - import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.streaming._ +import org.apache.spark.sql.functions.{count, session_window} /** * Counts words in UTF8 encoded, '\n' delimited text received from the network. * - * Usage: MapGroupsWithState <hostname> <port> + * Usage: StructuredSessionization <hostname> <port> * <hostname> and <port> describe the TCP server that Structured Streaming * would connect to receive data. * @@ -63,46 +61,20 @@ object StructuredSessionization { .option("includeTimestamp", true) .load() - // Split the lines into words, treat words as sessionId of events + // Split the lines into words, retaining timestamps + // split() splits each line into an array, and explode() turns the array into multiple rows + // treat words as sessionId of events val events = lines - .as[(String, Timestamp)] - .flatMap { case (line, timestamp) => - line.split(" ").map(word => Event(sessionId = word, timestamp)) - } + .selectExpr("explode(split(value, ' ')) AS sessionId", "timestamp AS eventTime") // Sessionize the events. Track number of events, start and end timestamps of session, // and report session updates. 
val sessionUpdates = events - .groupByKey(event => event.sessionId) - .mapGroupsWithState[SessionInfo, SessionUpdate](GroupStateTimeout.ProcessingTimeTimeout) { - - case (sessionId: String, events: Iterator[Event], state: GroupState[SessionInfo]) => - - // If timed out, then remove session and send final update - if (state.hasTimedOut) { - val finalUpdate = - SessionUpdate(sessionId, state.get.durationMs, state.get.numEvents, expired = true) - state.remove() - finalUpdate - } else { - // Update start and end timestamps in session - val timestamps = events.map(_.timestamp.getTime).toSeq - val updatedSession = if (state.exists) { - val oldSession = state.get - SessionInfo( - oldSession.numEvents + timestamps.size, - oldSession.startTimestampMs, - math.max(oldSession.endTimestampMs, timestamps.max)) - } else { - SessionInfo(timestamps.size, timestamps.min, timestamps.max) - } - state.update(updatedSession) - - // Set timeout such that the session will be expired if no data received for 10 seconds - state.setTimeoutDuration("10 seconds") - SessionUpdate(sessionId, state.get.durationMs, state.get.numEvents, expired = false) - } - } + .groupBy(session_window($"eventTime", "10 seconds") as 'session, 'sessionId) + .agg(count("*").as("numEvents")) + .selectExpr("sessionId", "CAST(session.start AS LONG)", "CAST(session.end AS LONG)", + "CAST(session.end AS LONG) - CAST(session.start AS LONG) AS durationMs", + "numEvents") // Start running the query that prints the session updates to the console val query = sessionUpdates @@ -114,38 +86,4 @@ object StructuredSessionization { query.awaitTermination() } } -/** User-defined data type representing the input events */ -case class Event(sessionId: String, timestamp: Timestamp) - -/** - * User-defined data type for storing a session information as state in mapGroupsWithState. 
- * - * @param numEvents total number of events received in the session - * @param startTimestampMs timestamp of first event received in the session when it started - * @param endTimestampMs timestamp of last event received in the session before it expired - */ -case class SessionInfo( - numEvents: Int, - startTimestampMs: Long, - endTimestampMs: Long) { - - /** Duration of the session, between the first and last events */ - def durationMs: Long = endTimestampMs - startTimestampMs -} - -/** - * User-defined data type representing the update information returned by mapGroupsWithState. - * - * @param id Id of the session - * @param durationMs Duration the session was active, that is, from first event to its expiry - * @param numEvents Number of events received by the session while it was active - * @param expired Is the session active or expired - */ -case class SessionUpdate( - id: String, - durationMs: Long, - numEvents: Int, - expired: Boolean) - // scalastyle:on println -