pyspark.sql.GroupedData.transformWithStateInPandas

GroupedData.transformWithStateInPandas(statefulProcessor, outputStructType, outputMode, timeMode, initialState=None, eventTimeColumnName='')

Invokes methods defined in the stateful processor used in the arbitrary state API v2. It requires protobuf, pandas, and pyarrow as dependencies to process input and state data. The user can act on a per-group set of input rows along with keyed state, and can output (return) zero or more rows.

For a streaming DataFrame, the interface methods are invoked repeatedly for new rows in each trigger, and the user's state and state variables are stored persistently across invocations.
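The invocation model described above can be pictured with a small pure-Python simulation. This is not PySpark code: `state_store`, `handle_input_rows`, and `run_trigger` are hypothetical names, the dict stands in for the persistent state store, and each call to `run_trigger` plays the role of one micro-batch trigger.

```python
from collections import defaultdict
from typing import Dict, Iterator, List, Tuple

# Hypothetical stand-in for the persistent state store: one running total per key.
state_store: Dict[str, int] = {}


def handle_input_rows(key: str, rows: Iterator[int]) -> Iterator[Tuple[str, int]]:
    """Hypothetical per-key handler: reads state, folds in new rows, writes state back."""
    total = state_store.get(key, 0)
    for value in rows:
        total += value
    state_store[key] = total  # the next trigger for this key starts from here
    yield (key, total)


def run_trigger(batch: List[Tuple[str, int]]) -> List[Tuple[str, int]]:
    """Group one micro-batch by key and call the handler once per key that has new rows."""
    groups: Dict[str, List[int]] = defaultdict(list)
    for key, value in batch:
        groups[key].append(value)
    output: List[Tuple[str, int]] = []
    for key, values in groups.items():
        output.extend(handle_input_rows(key, iter(values)))
    return output


first = run_trigger([("a", 1), ("b", 5), ("a", 2)])
second = run_trigger([("a", 10)])  # "b" has no new rows, so its handler is not called here
print(first)   # [('a', 3), ('b', 5)]
print(second)  # [('a', 13)]
```

In this model the second trigger for key "a" starts from the total saved by the first, which is the property the persistent state provides.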

The statefulProcessor should be an instance of a Python class that implements the interface defined in StatefulProcessor.

The outputStructType should be a StructType describing the schema of every pandas.DataFrame returned by the processor. The column labels of each returned pandas.DataFrame must either match the field names in the defined schema, if the labels are strings, or match the field data types by position, if they are not strings (e.g. integer indices).
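A simplified model of the column-matching rule, assuming a hypothetical helper `resolve_output_columns` that is not part of PySpark: string labels are looked up by schema field name, while non-string labels such as a default integer index are taken by position. Treating a frame as name-matched when any label is a string is a modeling choice in this sketch.

```python
from typing import List, Sequence, Union

Label = Union[str, int]


def resolve_output_columns(labels: Sequence[Label], field_names: Sequence[str]) -> List[Label]:
    """Return the frame labels to read, in schema field order (hypothetical helper)."""
    if any(isinstance(label, str) for label in labels):
        # String labels: match each schema field by name, so column order does not matter.
        missing = [name for name in field_names if name not in labels]
        if missing:
            raise ValueError(f"returned frame is missing columns: {missing}")
        return list(field_names)
    # Non-string labels (e.g. a default 0, 1, ... index): match fields by position.
    if len(labels) != len(field_names):
        raise ValueError("positional columns must match the number of schema fields")
    return list(labels)


schema_fields = ["id", "count"]
print(resolve_output_columns(["count", "id"], schema_fields))  # ['id', 'count']
print(resolve_output_columns([0, 1], schema_fields))            # [0, 1]
```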

The size of each pandas.DataFrame in both the input and output can be arbitrary, and so can the number of pandas.DataFrames in both the input and output.
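Because a group's rows may arrive split across any number of input DataFrames of any size, per-group logic should accumulate over all of them. The pure-Python sketch below uses plain lists in place of pandas.DataFrame chunks, with hypothetical helpers `chunk` and `count_and_violations`; the result is the same however the rows are split.

```python
from typing import Iterator, List, Tuple


def chunk(values: List[int], sizes: List[int]) -> Iterator[List[int]]:
    """Split values into consecutive chunks of the given sizes; sizes should sum to len(values)."""
    start = 0
    for size in sizes:
        yield values[start:start + size]
        start += size


def count_and_violations(chunks: Iterator[List[int]], threshold: int = 100) -> Tuple[int, int]:
    """Accumulate over every chunk rather than assuming one chunk holds the whole group."""
    count = 0
    violations = 0
    for part in chunks:  # a chunk may be empty or hold any number of rows
        count += len(part)
        violations += sum(1 for temperature in part if temperature > threshold)
    return count, violations


temperatures = [33, 188, 88]
print(count_and_violations(chunk(temperatures, [3])))        # (3, 1)
print(count_and_violations(chunk(temperatures, [1, 0, 2])))  # (3, 1)
```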

New in version 4.0.0.

Parameters
statefulProcessor : pyspark.sql.streaming.stateful_processor.StatefulProcessor

Instance of StatefulProcessor whose functions will be invoked by the operator.

outputStructType : pyspark.sql.types.DataType or str

The type of the output records. The value can be either a pyspark.sql.types.DataType object or a DDL-formatted type string.

outputMode : str

The output mode of the stateful processor.

timeMode : str

The time mode semantics of the stateful processor for timers and TTL.

initialState : pyspark.sql.GroupedData

Optional. A grouped DataFrame holding initial state, used to initialize state variables in the first batch.
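As a rough model of initialState, the sketch below seeds a per-key state dict once, before the first batch's rows are processed, and lets later batches build on it. It is an illustration only: `run_query` is a hypothetical name, and the dict stands in for the processor's state variables.

```python
from typing import Dict, List, Optional, Tuple

Batch = List[Tuple[str, int]]  # (id, temperature) rows in one micro-batch


def run_query(batches: List[Batch], initial_state: Optional[Dict[str, int]] = None) -> Dict[str, int]:
    """Count readings above 100 per id, optionally seeded from an initial state."""
    state: Dict[str, int] = {}
    for batch_id, batch in enumerate(batches):
        if batch_id == 0 and initial_state is not None:
            # Hypothetical analogue of initialState: applied once, in the first batch only.
            state.update(initial_state)
        for key, temperature in batch:
            if temperature > 100:
                state[key] = state.get(key, 0) + 1
    return state


batches = [[("0", 123), ("1", 33)], [("1", 188)]]
print(run_query(batches))                          # {'0': 1, '1': 1}
print(run_query(batches, initial_state={"1": 5}))  # {'1': 6, '0': 1}
```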

Notes

This function requires a full shuffle.
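The full shuffle is what routes all rows for a grouping key to the same task, so each key's rows are handled together with that key's state. A minimal sketch of hash partitioning by key, using a hypothetical `shuffle_by_key` helper and `zlib.crc32` purely as a stable stand-in hash (Spark's hash partitioning uses its own Murmur3-based hash):

```python
import zlib
from collections import defaultdict
from typing import Dict, List, Set, Tuple

KeyedRow = Tuple[str, int]  # (grouping key, value)


def shuffle_by_key(rows: List[KeyedRow], num_partitions: int) -> Dict[int, List[KeyedRow]]:
    """Send each row to the partition picked by hashing its grouping key (hypothetical helper)."""
    partitions: Dict[int, List[KeyedRow]] = defaultdict(list)
    for key, value in rows:
        # crc32 is only a stable stand-in; Spark's hash partitioning uses a Murmur3-based hash.
        pid = zlib.crc32(key.encode("utf-8")) % num_partitions
        partitions[pid].append((key, value))
    return dict(partitions)


rows = [("0", 123), ("0", 23), ("1", 33), ("1", 188), ("1", 88)]
parts = shuffle_by_key(rows, num_partitions=4)

# Every key lands in exactly one partition, whichever partition that happens to be.
key_to_parts: Dict[str, Set[int]] = defaultdict(set)
for pid, part_rows in parts.items():
    for key, _ in part_rows:
        key_to_parts[key].add(pid)
print({key: len(pids) for key, pids in key_to_parts.items()})
```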

Examples

>>> import pandas as pd
>>> from pyspark.sql.streaming import StatefulProcessor, StatefulProcessorHandle
>>> from pyspark.sql.types import IntegerType, StringType, StructField, StructType
>>> spark.conf.set("spark.sql.streaming.stateStore.providerClass",
...     "org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider")
... # Below is a simple example to find erroneous sensors from temperature sensor data. The
... # processor returns a count of total readings, while keeping erroneous reading counts
... # in streaming state. A violation is defined when the temperature is above 100.
... # The input data is a DataFrame with the following schema:
... #    `id: string, temperature: long`.
... # The output schema and state schema are defined as below.
>>> output_schema = StructType([
...     StructField("id", StringType(), True),
...     StructField("count", IntegerType(), True)
... ])
>>> state_schema = StructType([
...     StructField("value", IntegerType(), True)
... ])
>>> class SimpleStatefulProcessor(StatefulProcessor):
...     def init(self, handle: StatefulProcessorHandle):
...         self.num_violations_state = handle.getValueState("numViolations", state_schema)
...
...     def handleInputRows(self, key, rows):
...         new_violations = 0
...         count = 0
...         exists = self.num_violations_state.exists()
...         if exists:
...             existing_violations_row = self.num_violations_state.get()
...             existing_violations = existing_violations_row[0]
...         else:
...             existing_violations = 0
...         for pdf in rows:
...             pdf_count = pdf.count()
...             count += pdf_count.get('temperature')
...             violations_pdf = pdf.loc[pdf['temperature'] > 100]
...             new_violations += violations_pdf.count().get('temperature')
...         updated_violations = new_violations + existing_violations
...         self.num_violations_state.update((updated_violations,))
...         yield pd.DataFrame({'id': key, 'count': count})
...
...     def close(self) -> None:
...         pass

Input DataFrame:

+---+-----------+
| id|temperature|
+---+-----------+
|  0|        123|
|  0|         23|
|  1|         33|
|  1|        188|
|  1|         88|
+---+-----------+

>>> df.groupBy("id").transformWithStateInPandas(
...     statefulProcessor=SimpleStatefulProcessor(),
...     outputStructType=output_schema,
...     outputMode="Update",
...     timeMode="None",
... )

Output DataFrame:

+---+-----+
| id|count|
+---+-----+
|  0|    2|
|  1|    3|
+---+-----+
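The numbers in the output table can be checked by hand. The snippet below recomputes, in plain Python rather than Spark, the per-id reading count that the processor emits and the violation count it keeps in state, assuming all input rows arrive in a single micro-batch.

```python
from collections import Counter

# Rows from the input table: (id, temperature); ids are strings, matching `id: string`.
rows = [("0", 123), ("0", 23), ("1", 33), ("1", 188), ("1", 88)]

# What the processor emits: the number of readings per id in this batch.
counts = Counter(key for key, _ in rows)
# What the processor keeps in state: readings above 100 per id.
violations = Counter(key for key, temperature in rows if temperature > 100)

print(sorted(counts.items()))      # [('0', 2), ('1', 3)]
print(sorted(violations.items()))  # [('0', 1), ('1', 1)]
```

The emitted counts match the output table, while the violation counts stay in state and never appear in the output.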