pyspark.sql.GroupedData.transformWithStateInPandas
- GroupedData.transformWithStateInPandas(statefulProcessor, outputStructType, outputMode, timeMode, initialState=None, eventTimeColumnName='')
Invokes methods defined in the stateful processor used in the arbitrary state API v2. It requires protobuf, pandas, and pyarrow as dependencies to process input/state data. The user acts on a per-group set of input rows along with keyed state, and can choose to output/return zero or more rows.
For a streaming DataFrame, the interface methods are invoked repeatedly for the new rows in each trigger, and the user's state/state variables are stored persistently across invocations.
The statefulProcessor should be a Python class that implements the interface defined in pyspark.sql.streaming.StatefulProcessor. The outputStructType should be a pyspark.sql.types.StructType describing the schema of all elements in the returned value, a pandas.DataFrame. The column labels of all elements in the returned pandas.DataFrame must either match the field names in the defined schema if specified as strings, or match the field data types by position if not strings, e.g. integer indices.
The size of each pandas.DataFrame in both the input and output can be arbitrary, and so can the number of pandas.DataFrames in the input and output.
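For illustration, a minimal sketch of the two equivalent ways to declare an output schema and the two column-labeling conventions; the names out_as_struct, out_as_ddl, by_name, and by_position are illustrative, not part of the API:
>>> import pandas as pd
>>> from pyspark.sql.types import StructType, StructField, StringType, IntegerType
>>> # The output schema can be a StructType object ...
>>> out_as_struct = StructType([
...     StructField("id", StringType(), True),
...     StructField("count", IntegerType(), True)
... ])
>>> # ... or an equivalent DDL-formatted type string.
>>> out_as_ddl = "id string, count int"
>>> # A returned pandas.DataFrame may label columns with the field names
>>> # (matched by name) ...
>>> by_name = pd.DataFrame({"id": ["0"], "count": [2]})
>>> # ... or use default integer column labels (matched by field position).
>>> by_position = pd.DataFrame([["0", 2]])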
New in version 4.0.0.
- Parameters
- statefulProcessor : pyspark.sql.streaming.stateful_processor.StatefulProcessor
Instance of StatefulProcessor whose functions will be invoked by the operator.
- outputStructType : pyspark.sql.types.DataType or str
The type of the output records. The value can be either a pyspark.sql.types.DataType object or a DDL-formatted type string.
- outputMode : str
The output mode of the stateful processor.
- timeMode : str
The time mode semantics of the stateful processor for timers and TTL (a timer sketch follows the example in the Examples section below).
- initialState : pyspark.sql.GroupedData
Optional. The grouped DataFrame as initial states used for initialization of state variables in the first batch (see the sketch after this parameter list).
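A minimal call-site sketch for the optional initialState parameter, assuming a hypothetical batch DataFrame init_df that seeds each group's state, and reusing the df, SimpleStatefulProcessor, and output_schema defined in the Examples below:
>>> # Hypothetical seed data: one initial violation count per sensor id.
>>> init_df = spark.createDataFrame([("0", 5), ("1", 0)], "id string, value int")
>>> df.groupBy("id").transformWithStateInPandas(
...     statefulProcessor=SimpleStatefulProcessor(),
...     outputStructType=output_schema,
...     outputMode="Update",
...     timeMode="None",
...     initialState=init_df.groupBy("id"))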
Notes
This function requires a full shuffle.
Examples
>>> from typing import Iterator
>>> import pandas as pd
>>> from pyspark.sql import Row
>>> from pyspark.sql.functions import col, split
>>> from pyspark.sql.streaming import StatefulProcessor, StatefulProcessorHandle
>>> from pyspark.sql.types import IntegerType, LongType, StringType, StructField, StructType
>>> spark.conf.set("spark.sql.streaming.stateStore.providerClass",
...     "org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider")
>>> # Below is a simple example to find erroneous sensors from temperature sensor data. The
>>> # processor returns a count of total readings, while keeping erroneous reading counts
>>> # in streaming state. A violation is defined when the temperature is above 100.
>>> # The input data is a DataFrame with the following schema:
>>> # `id: string, temperature: long`.
>>> # The output schema and state schema are defined as below.
>>> output_schema = StructType([
...     StructField("id", StringType(), True),
...     StructField("count", IntegerType(), True)
... ])
>>> state_schema = StructType([
...     StructField("value", IntegerType(), True)
... ])
>>> class SimpleStatefulProcessor(StatefulProcessor):
...     def init(self, handle: StatefulProcessorHandle):
...         # Keyed ValueState holding the running violation count for this sensor.
...         self.num_violations_state = handle.getValueState("numViolations", state_schema)
...
...     def handleInputRows(self, key, rows):
...         new_violations = 0
...         count = 0
...         exists = self.num_violations_state.exists()
...         if exists:
...             existing_violations_row = self.num_violations_state.get()
...             existing_violations = existing_violations_row[0]
...         else:
...             existing_violations = 0
...         for pdf in rows:
...             pdf_count = pdf.count()
...             count += pdf_count.get('temperature')
...             violations_pdf = pdf.loc[pdf['temperature'] > 100]
...             new_violations += violations_pdf.count().get('temperature')
...         updated_violations = new_violations + existing_violations
...         self.num_violations_state.update((updated_violations,))
...         yield pd.DataFrame({'id': key, 'count': count})
...
...     def close(self) -> None:
...         pass
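When timeMode is "processingTime", the processor may also register timers. The sketch below is a hypothetical variant of the processor above that emits the stored count when a timer fires; the timer callback names (registerTimer, handleExpiredTimer) and their exact signatures can differ across Spark 4.0 builds, so treat this as an outline rather than a verbatim recipe:
>>> import time
>>> class TimedStatefulProcessor(StatefulProcessor):
...     def init(self, handle: StatefulProcessorHandle):
...         self.handle = handle
...         self.num_violations_state = handle.getValueState("numViolations", state_schema)
...
...     def handleInputRows(self, key, rows):
...         # Register a processing-time timer 10 seconds from now (illustrative;
...         # some builds also pass a timerValues argument exposing the current
...         # processing time).
...         self.handle.registerTimer(int(time.time() * 1000) + 10000)
...         for pdf in rows:
...             yield pd.DataFrame({'id': key, 'count': pdf.count().get('temperature')})
...
...     def handleExpiredTimer(self, key, timerValues, expiredTimerInfo):
...         # Invoked once per expired timer; emit the stored violation count.
...         if self.num_violations_state.exists():
...             existing = self.num_violations_state.get()
...         else:
...             existing = (0,)
...         yield pd.DataFrame({'id': key, 'count': existing[0]})
...
...     def close(self) -> None:
...         pass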
Input DataFrame:
+---+-----------+
| id|temperature|
+---+-----------+
|  0|        123|
|  0|         23|
|  1|         33|
|  1|        188|
|  1|         88|
+---+-----------+
>>> df.groupBy("id").transformWithStateInPandas(statefulProcessor =
...     SimpleStatefulProcessor(), outputStructType=output_schema, outputMode="Update",
...     timeMode="None")
Output DataFrame:
+---+-----+
| id|count|
+---+-----+
|  0|    2|
|  1|    3|
+---+-----+
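The call above shows the transformation by itself; in practice it runs inside a streaming query, whose interface methods are re-invoked on each trigger as described earlier. A minimal end-to-end sketch, assuming a hypothetical socket source emitting comma-separated `id,temperature` lines and an illustrative checkpoint path:
>>> lines = spark.readStream.format("socket") \
...     .option("host", "localhost").option("port", 9999).load()
>>> parsed = lines.select(
...     split(col("value"), ",").getItem(0).alias("id"),
...     split(col("value"), ",").getItem(1).cast(LongType()).alias("temperature"))
>>> query = parsed.groupBy("id").transformWithStateInPandas(
...         statefulProcessor=SimpleStatefulProcessor(),
...         outputStructType=output_schema,
...         outputMode="Update",
...         timeMode="None") \
...     .writeStream.outputMode("update").format("console") \
...     .option("checkpointLocation", "/tmp/twspandas-checkpoint") \
...     .start()
The checkpoint location is required so the keyed state survives restarts; the RocksDB state store provider configured in the example above is the typical backing store for this operator.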