#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import ClassVar, Type, Dict, List, Optional, Union, cast, TYPE_CHECKING
from pyspark.util import local_connect_and_auth
from pyspark.serializers import read_int, write_int, write_with_length, UTF8Deserializer
from pyspark.errors import PySparkRuntimeError
if TYPE_CHECKING:
from pyspark.resource import ResourceInformation
class TaskContext:
    """
    Contextual information about a task which can be read or mutated during
    execution. To access the TaskContext for a running task, use:
    :meth:`TaskContext.get`.

    .. versionadded:: 2.2.0

    Examples
    --------
    >>> from pyspark import TaskContext

    Get a task context instance from :class:`RDD`.

    >>> spark.sparkContext.setLocalProperty("key1", "value")
    >>> taskcontext = spark.sparkContext.parallelize([1]).map(lambda _: TaskContext.get()).first()
    >>> isinstance(taskcontext.attemptNumber(), int)
    True
    >>> isinstance(taskcontext.partitionId(), int)
    True
    >>> isinstance(taskcontext.stageId(), int)
    True
    >>> isinstance(taskcontext.taskAttemptId(), int)
    True
    >>> taskcontext.getLocalProperty("key1")
    'value'
    >>> isinstance(taskcontext.cpus(), int)
    True

    Get a task context instance from a dataframe via Python UDF.

    >>> from pyspark.sql import Row
    >>> from pyspark.sql.functions import udf
    >>> @udf("STRUCT<anum: INT, partid: INT, stageid: INT, taskaid: INT, prop: STRING, cpus: INT>")
    ... def taskcontext_as_row():
    ...     taskcontext = TaskContext.get()
    ...     return Row(
    ...         anum=taskcontext.attemptNumber(),
    ...         partid=taskcontext.partitionId(),
    ...         stageid=taskcontext.stageId(),
    ...         taskaid=taskcontext.taskAttemptId(),
    ...         prop=taskcontext.getLocalProperty("key2"),
    ...         cpus=taskcontext.cpus())
    ...
    >>> spark.sparkContext.setLocalProperty("key2", "value")
    >>> [(anum, partid, stageid, taskaid, prop, cpus)] = (
    ...     spark.range(1).select(taskcontext_as_row()).first()
    ... )
    >>> isinstance(anum, int)
    True
    >>> isinstance(partid, int)
    True
    >>> isinstance(stageid, int)
    True
    >>> isinstance(taskaid, int)
    True
    >>> prop
    'value'
    >>> isinstance(cpus, int)
    True

    Get a task context instance from a dataframe via Pandas UDF.

    >>> import pandas as pd  # doctest: +SKIP
    >>> from pyspark.sql.functions import pandas_udf
    >>> @pandas_udf("STRUCT<"
    ...     "anum: INT, partid: INT, stageid: INT, taskaid: INT, prop: STRING, cpus: INT>")
    ... def taskcontext_as_row(_):
    ...     taskcontext = TaskContext.get()
    ...     return pd.DataFrame({
    ...         "anum": [taskcontext.attemptNumber()],
    ...         "partid": [taskcontext.partitionId()],
    ...         "stageid": [taskcontext.stageId()],
    ...         "taskaid": [taskcontext.taskAttemptId()],
    ...         "prop": [taskcontext.getLocalProperty("key3")],
    ...         "cpus": [taskcontext.cpus()]
    ...     })  # doctest: +SKIP
    ...
    >>> spark.sparkContext.setLocalProperty("key3", "value")  # doctest: +SKIP
    >>> [(anum, partid, stageid, taskaid, prop, cpus)] = (
    ...     spark.range(1).select(taskcontext_as_row("id")).first()
    ... )  # doctest: +SKIP
    >>> isinstance(anum, int)
    True
    >>> isinstance(partid, int)
    True
    >>> isinstance(stageid, int)
    True
    >>> isinstance(taskaid, int)
    True
    >>> prop
    'value'
    >>> isinstance(cpus, int)
    True
    """

    # Process-wide singleton; ``None`` until a context is created or installed.
    _taskContext: ClassVar[Optional["TaskContext"]] = None

    # Per-task fields. They default to ``None`` and are assigned from outside this class;
    # the accessor methods below simply hand back whatever is currently stored.
    _attemptNumber: Optional[int] = None
    _partitionId: Optional[int] = None
    _stageId: Optional[int] = None
    _taskAttemptId: Optional[int] = None
    _localProperties: Optional[Dict[str, str]] = None
    _cpus: Optional[int] = None
    _resources: Optional[Dict[str, "ResourceInformation"]] = None

    def __new__(cls: Type["TaskContext"]) -> "TaskContext":
        """
        Even if users construct :class:`TaskContext` instead of using get, give them the singleton.

        Returns the instance already stored in ``_taskContext`` when there is one; otherwise
        allocates a new instance, stores it on ``cls`` and returns it.
        """
        taskContext = cls._taskContext
        if taskContext is not None:
            return taskContext
        cls._taskContext = taskContext = object.__new__(cls)
        return taskContext

    @classmethod
    def _getOrCreate(cls: Type["TaskContext"]) -> "TaskContext":
        """Internal function to get or create global :class:`TaskContext`."""
        if cls._taskContext is None:
            cls._taskContext = TaskContext()
        return cls._taskContext

    @classmethod
    def _setTaskContext(cls: Type["TaskContext"], taskContext: "TaskContext") -> None:
        """Internal function that installs ``taskContext`` as the global :class:`TaskContext`."""
        cls._taskContext = taskContext

    @classmethod
    def get(cls: Type["TaskContext"]) -> Optional["TaskContext"]:
        """
        Return the currently active :class:`TaskContext`. This can be called inside of
        user functions to access contextual information about running tasks.

        Returns
        -------
        :class:`TaskContext`, optional

        Notes
        -----
        Must be called on the worker, not the driver. Returns ``None`` if not initialized.
        """
        return cls._taskContext

    def stageId(self) -> int:
        """
        The ID of the stage that this task belongs to.

        Returns
        -------
        int
            current stage id.
        """
        return cast(int, self._stageId)

    def partitionId(self) -> int:
        """
        The ID of the RDD partition that is computed by this task.

        Returns
        -------
        int
            current partition id.
        """
        return cast(int, self._partitionId)

    def attemptNumber(self) -> int:
        """
        How many times this task has been attempted. The first task attempt will be assigned
        attemptNumber = 0, and subsequent attempts will have increasing attempt numbers.

        Returns
        -------
        int
            current attempt number.
        """
        return cast(int, self._attemptNumber)

    def taskAttemptId(self) -> int:
        """
        An ID that is unique to this task attempt (within the same :class:`SparkContext`,
        no two task attempts will share the same attempt ID). This is roughly equivalent
        to Hadoop's `TaskAttemptID`.

        Returns
        -------
        int
            current task attempt id.
        """
        return cast(int, self._taskAttemptId)

    def getLocalProperty(self, key: str) -> Optional[str]:
        """
        Get a local property set upstream in the driver, or None if it is missing.

        Parameters
        ----------
        key : str
            the key of the local property to get.

        Returns
        -------
        str, optional
            the value of the local property, or ``None`` when ``key`` is not present.
        """
        return cast(Dict[str, str], self._localProperties).get(key, None)

    def cpus(self) -> int:
        """
        CPUs allocated to the task.

        Returns
        -------
        int
            the number of CPUs.
        """
        return cast(int, self._cpus)

    def resources(self) -> Dict[str, "ResourceInformation"]:
        """
        Resources allocated to the task. The key is the resource name and the value is information
        about the resource.

        Returns
        -------
        dict
            a dictionary of a string resource name, and :class:`ResourceInformation`.
        """
        # The annotation is a quoted forward reference, so no runtime import is needed here.
        return cast(Dict[str, "ResourceInformation"], self._resources)
# Request codes written to the socket as the first integer of every request.
BARRIER_FUNCTION = 1
ALL_GATHER_FUNCTION = 2


def _load_from_socket(
    port: Optional[Union[str, int]],
    auth_secret: str,
    function: int,
    all_gather_message: Optional[str] = None,
) -> List[str]:
    """
    Send a barrier()/all_gather() request over a local authenticated socket and read the reply.

    This is a blocking call: the socket timeout is disabled, so it only returns once the
    whole response has been read.

    Parameters
    ----------
    port : str or int, optional
        passed through to ``local_connect_and_auth``.
    auth_secret : str
        passed through to ``local_connect_and_auth``.
    function : int
        either ``BARRIER_FUNCTION`` or ``ALL_GATHER_FUNCTION``.
    all_gather_message : str, optional
        message sent (UTF-8 encoded) after the function code; only used for
        ``ALL_GATHER_FUNCTION``.

    Returns
    -------
    list of str
        the strings read back from the socket: first an integer count, then that many
        UTF-8 strings.

    Raises
    ------
    ValueError
        if ``function`` is not one of the two recognized codes.
    """
    # Reject unknown request codes up front so that no connection is opened (and then
    # left dangling) for a request that can never be sent.
    if function not in (BARRIER_FUNCTION, ALL_GATHER_FUNCTION):
        raise ValueError("Unrecognized function type")

    (sockfile, sock) = local_connect_and_auth(port, auth_secret)
    try:
        # The call may block forever, so no timeout
        sock.settimeout(None)

        # Every request starts with the function code; all_gather() also sends its message.
        write_int(function, sockfile)
        if function == ALL_GATHER_FUNCTION:
            write_with_length(cast(str, all_gather_message).encode("utf-8"), sockfile)
        sockfile.flush()

        # Collect result.
        num_messages = read_int(sockfile)
        deserializer = UTF8Deserializer()
        return [deserializer.loads(sockfile) for _ in range(num_messages)]
    finally:
        # Release resources even when writing or reading fails part-way through.
        try:
            sockfile.close()
        finally:
            sock.close()
class BarrierTaskContext(TaskContext):
    """
    A :class:`TaskContext` with extra contextual info and tooling for tasks in a barrier stage.
    Use :func:`BarrierTaskContext.get` to obtain the barrier context for a running barrier task.

    .. versionadded:: 2.4.0

    Notes
    -----
    This API is experimental

    Examples
    --------
    Set a barrier, and execute it with RDD.

    >>> from pyspark import BarrierTaskContext
    >>> def block_and_do_something(itr):
    ...     taskcontext = BarrierTaskContext.get()
    ...     # Do something.
    ...
    ...     # Wait until all tasks finished.
    ...     taskcontext.barrier()
    ...
    ...     return itr
    ...
    >>> rdd = spark.sparkContext.parallelize([1])
    >>> rdd.barrier().mapPartitions(block_and_do_something).collect()
    [1]
    """

    # Connection details handed to ``_load_from_socket``; ``None`` until ``_initialize`` runs.
    _port: ClassVar[Optional[Union[str, int]]] = None
    _secret: ClassVar[Optional[str]] = None

    @classmethod
    def _getOrCreate(cls: Type["BarrierTaskContext"]) -> "BarrierTaskContext":
        """
        Internal function to get or create global :class:`BarrierTaskContext`. We need to make sure
        :class:`BarrierTaskContext` is returned from here because it is needed in python worker
        reuse scenario, see SPARK-25921 for more details.
        """
        # Replace whatever is stored (nothing, or a non-barrier context) with a fresh
        # barrier context.
        if not isinstance(cls._taskContext, BarrierTaskContext):
            cls._taskContext = object.__new__(cls)
        return cls._taskContext

    @classmethod
    def get(cls: Type["BarrierTaskContext"]) -> "BarrierTaskContext":
        """
        Return the currently active :class:`BarrierTaskContext`.
        This can be called inside of user functions to access contextual information about
        running tasks.

        Returns
        -------
        :class:`BarrierTaskContext`

        Raises
        ------
        PySparkRuntimeError
            with error class ``NOT_IN_BARRIER_STAGE`` if the stored task context is not a
            :class:`BarrierTaskContext` (including when no context has been set at all).

        Notes
        -----
        Must be called on the worker, not the driver.

        This API is experimental
        """
        if not isinstance(cls._taskContext, BarrierTaskContext):
            raise PySparkRuntimeError(
                error_class="NOT_IN_BARRIER_STAGE",
                message_parameters={},
            )
        return cls._taskContext

    @classmethod
    def _initialize(
        cls: Type["BarrierTaskContext"], port: Optional[Union[str, int]], secret: str
    ) -> None:
        """
        Initialize :class:`BarrierTaskContext`, other methods within :class:`BarrierTaskContext`
        can only be called after BarrierTaskContext is initialized.
        """
        cls._port = port
        cls._secret = secret

    def barrier(self) -> None:
        """
        Sets a global barrier and waits until all tasks in this stage hit this barrier.
        Similar to `MPI_Barrier` function in MPI, this function blocks until all tasks
        in the same stage have reached this routine.

        .. versionadded:: 2.4.0

        Raises
        ------
        PySparkRuntimeError
            with error class ``CALL_BEFORE_INITIALIZE`` if the port or secret has not been set.

        Notes
        -----
        This API is experimental

        In a barrier stage, each task must have the same number of `barrier()`
        calls, in all possible code branches. Otherwise, you may get the job hanging
        or a `SparkException` after timeout.
        """
        if self._port is None or self._secret is None:
            raise PySparkRuntimeError(
                error_class="CALL_BEFORE_INITIALIZE",
                message_parameters={
                    "func_name": "barrier",
                    "object": "BarrierTaskContext",
                },
            )
        # The returned list is intentionally discarded; only the blocking matters here.
        _load_from_socket(self._port, self._secret, BARRIER_FUNCTION)

    def allGather(self, message: str = "") -> List[str]:
        """
        This function blocks until all tasks in the same stage have reached this routine.
        Each task passes in a message and returns with a list of all the messages passed in
        by each of those tasks.

        .. versionadded:: 3.0.0

        Parameters
        ----------
        message : str, optional
            the message contributed by this task; defaults to the empty string.

        Returns
        -------
        list of str
            the messages read back from the socket.

        Raises
        ------
        TypeError
            if ``message`` is not a ``str``.
        PySparkRuntimeError
            with error class ``CALL_BEFORE_INITIALIZE`` if the port or secret has not been set.

        Notes
        -----
        This API is experimental

        In a barrier stage, each task must have the same number of `barrier()`
        calls, in all possible code branches. Otherwise, you may get the job hanging
        or a `SparkException` after timeout.
        """
        # The argument type is checked before the initialization state.
        if not isinstance(message, str):
            raise TypeError("Argument `message` must be of type `str`")
        if self._port is None or self._secret is None:
            raise PySparkRuntimeError(
                error_class="CALL_BEFORE_INITIALIZE",
                message_parameters={
                    "func_name": "allGather",
                    "object": "BarrierTaskContext",
                },
            )
        return _load_from_socket(self._port, self._secret, ALL_GATHER_FUNCTION, message)

    def getTaskInfos(self) -> List["BarrierTaskInfo"]:
        """
        Returns :class:`BarrierTaskInfo` for all tasks in this barrier stage,
        ordered by partition ID.

        .. versionadded:: 2.4.0

        Returns
        -------
        list of :class:`BarrierTaskInfo`
            one entry per comma-separated item of the ``"addresses"`` local property.

        Raises
        ------
        PySparkRuntimeError
            with error class ``CALL_BEFORE_INITIALIZE`` if the port or secret has not been set.

        Notes
        -----
        This API is experimental

        Examples
        --------
        >>> from pyspark import BarrierTaskContext
        >>> rdd = spark.sparkContext.parallelize([1])
        >>> barrier_info = rdd.barrier().mapPartitions(
        ...     lambda _: [BarrierTaskContext.get().getTaskInfos()]).collect()[0][0]
        >>> barrier_info.address
        '...:...'
        """
        if self._port is None or self._secret is None:
            raise PySparkRuntimeError(
                error_class="CALL_BEFORE_INITIALIZE",
                message_parameters={
                    "func_name": "getTaskInfos",
                    "object": "BarrierTaskContext",
                },
            )
        # A missing "addresses" property falls back to "", which splits into a single
        # empty string and therefore yields one BarrierTaskInfo with an empty address.
        addresses = cast(Dict[str, str], self._localProperties).get("addresses", "")
        return [BarrierTaskInfo(h.strip()) for h in addresses.split(",")]
class BarrierTaskInfo:
    """
    Carries all task infos of a barrier task.

    .. versionadded:: 2.4.0

    Attributes
    ----------
    address : str
        The IPv4 address (host:port) of the executor that the barrier task is running on

    Notes
    -----
    This API is experimental
    """

    def __init__(self, address: str) -> None:
        # Stored as given; no parsing or validation is done here.
        self.address = address
def _test() -> None:
    """
    Run this module's doctests with a ``spark`` session available in the doctest globals.

    A session is created with master ``local[2]``, the doctests are run with the
    ``ELLIPSIS`` option, the session is stopped, and the process exits with status ``-1``
    if any doctest failed.
    """
    import doctest
    import sys
    from pyspark.sql import SparkSession

    globs = globals().copy()
    spark = SparkSession.builder.master("local[2]").appName("taskcontext tests").getOrCreate()
    globs["spark"] = spark
    try:
        (failure_count, _test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
    finally:
        # Stop the session even if the doctest run itself raises.
        spark.stop()
    if failure_count:
        sys.exit(-1)


if __name__ == "__main__":
    _test()