from __future__ import annotations
import logging
import os.path
import tempfile
import typing as t
from typing import TYPE_CHECKING
from bentoml._internal.utils import reserve_free_port
from ...bentos import import_bento
from ...bentos import serve
from ...client import Client
from ...client import HTTPClient
from ...exceptions import BentoMLException
from ...exceptions import MissingDependencyException
from ..bento import Bento
from ..service.loader import load_bento
from ..tag import Tag
# pyspark is an optional dependency: fail early with an actionable message
# instead of a bare ImportError when it is not installed.
try:
    from pyspark.files import SparkFiles
    from pyspark.sql.types import StructType
except ImportError:  # pragma: no cover (trivial error)
    raise MissingDependencyException(
        "'pyspark' is required in order to use module 'bentoml.spark', install pyspark with 'pip install pyspark'. For more information, refer to https://spark.apache.org/docs/latest/api/python/"
    )

# These names are referenced only inside annotations, which are never evaluated
# at runtime thanks to ``from __future__ import annotations``.
if TYPE_CHECKING:
    import pyspark.sql.dataframe
    import pyspark.sql.session

    RecordBatch: t.TypeAlias = t.Any  # pyarrow doesn't have type annotations


logger = logging.getLogger(__name__)
def _distribute_bento(spark: pyspark.sql.session.SparkSession, bento: Bento) -> str:
    """Export ``bento`` to a temporary directory and register it with Spark.

    Args:
        spark: The Spark session whose context the exported file is added to.
        bento: The bento to export.

    Returns:
        The file name (without directory) of the exported bento archive.
    """
    temp_dir = tempfile.mkdtemp()
    export_path = bento.export(temp_dir)
    # Register the archive with the SparkContext so that ``SparkFiles.get`` in
    # ``_load_bento_spark`` can resolve it by file name.
    # NOTE(review): the temporary directory is intentionally not removed here;
    # confirm whether Spark still needs the source path after ``addFile``.
    spark.sparkContext.addFile(export_path)
    return os.path.basename(export_path)
def _load_bento_spark(bento_tag: Tag):
    """Load a Bento from the local bento store or the SparkFiles directory.

    The bento is first loaded with ``load_bento``. If that fails, the archive
    is looked up through ``SparkFiles``, imported with ``import_bento``, and
    the load is retried.

    Args:
        bento_tag: Tag of the bento to load.

    Returns:
        Whatever ``load_bento`` returns for ``bento_tag``.

    Raises:
        Exception: The original loading error is re-raised when no matching
            archive file exists at the SparkFiles path.
    """
    try:
        return load_bento(bento_tag)
    except Exception:
        # Use the default Bento export file name. This relies on the implementation
        # of _distribute_bento to use default Bento export file name.
        # NOTE(review): assumes the default export name is "<name>-<version>.bento".
        bento_path = SparkFiles.get(f"{bento_tag.name}-{bento_tag.version}.bento")
        if not os.path.isfile(bento_path):
            # Nothing to fall back to: surface the original error.
            raise

        import_bento(bento_path)
        return load_bento(bento_tag)
def _get_process(
    bento_tag: Tag, api_name: str
) -> t.Callable[[t.Iterable[RecordBatch]], t.Generator[RecordBatch, None, None]]:
    """Build the batch-processing function that ``run_in_spark`` passes to ``mapInArrow``.

    Args:
        bento_tag: Tag of the bento to load and serve.
        api_name: Name of the inference API to call for every batch.

    Returns:
        A generator function that consumes an iterable of record batches and
        yields one output record batch per input batch.
    """

    def process(
        iterator: t.Iterable[RecordBatch],
    ) -> t.Generator[RecordBatch, None, None]:
        svc = _load_bento_spark(bento_tag)

        assert (
            api_name in svc.apis
        ), "An error occurred transferring the Bento to the Spark worker."
        inference_api = svc.apis[api_name]
        assert inference_api.func is not None, "Inference API function not defined"

        # start bento server
        # The context manager is used only to pick a port number; the server is
        # started after the block is left.
        # NOTE(review): presumably the reservation keeps the port bound while
        # the context is open, so the server must bind afterwards -- confirm.
        with reserve_free_port() as port:
            pass

        server = serve(bento_tag, port=port)
        Client.wait_until_server_ready("localhost", port, 30)
        client = HTTPClient(svc, server.url)
        # NOTE(review): the server is never shut down explicitly here; confirm
        # whether the handle returned by ``serve`` needs to be stopped.

        for batch in iterator:
            func_input = inference_api.input.from_arrow(batch)
            func_output = client.call(api_name, func_input)
            yield inference_api.output.to_arrow(func_output)

    return process
def run_in_spark(
    bento: Bento,
    df: pyspark.sql.dataframe.DataFrame,
    spark: pyspark.sql.session.SparkSession,
    api_name: str | None = None,
    output_schema: StructType | None = None,
) -> pyspark.sql.dataframe.DataFrame:
    """
    Run BentoService inference API in Spark.

    The API to run must accept batches as input and return batches as output.

    Args:
        bento:
            The bento containing the inference API to run.
        df:
            The input DataFrame to run the inference API on.
        spark:
            The spark session to use to run the inference API.
        api_name:
            The name of the inference API to run. If not provided, there must be only one API contained
            in the bento; that API will be run.
        output_schema:
            The Spark schema of the output DataFrame. If not provided, BentoML will attempt to infer the
            schema from the output descriptor of the inference API.

    Returns:
        The result of the inference API run on the input ``df``.

    Raises:
        BentoMLException: If ``api_name`` is omitted and the bento does not contain
            exactly one API, or if ``api_name`` is not an API of the bento.

    Examples
    --------

    .. code-block:: python

        >>> import bentoml
        >>> import pyspark
        >>> from pyspark.sql import SparkSession
        >>> from pyspark.sql.types import StructType, StructField, StringType, IntegerType

        >>> spark = SparkSession.builder.getOrCreate()
        >>> schema = StructType([
        ...     StructField("name", StringType(), True),
        ...     StructField("age", IntegerType(), True),
        ... ])
        >>> df = spark.createDataFrame([("John", 30), ("Mike", 25), ("Sally", 40)], schema)

        >>> bento = bentoml.get("my_service:latest")
        >>> results = bentoml.batch.run_in_spark(bento, df, spark)
        >>> results.show()
        +-----+---+
        | name|age|
        +-----+---+
        |John |30 |
        +-----+---+
    """
    svc = load_bento(bento)

    if api_name is None:
        # No API requested: only unambiguous when the bento exposes exactly one.
        if len(svc.apis) != 1:
            raise BentoMLException(
                f'Bento "{bento.tag}" has multiple APIs ({svc.apis.keys()}), specify which API should be run, e.g.: bentoml.batch.run_in_spark("my_service:latest", df, spark, api_name="predict")'
            )
        api_name = next(iter(svc.apis))
    elif api_name not in svc.apis:
        raise BentoMLException(
            f"API name '{api_name}' not found in Bento '{bento.tag}', available APIs are {svc.apis.keys()}"
        )

    api = svc.apis[api_name]

    # Export the bento and register it with Spark before the job is defined.
    _distribute_bento(spark, bento)

    process = _get_process(bento.tag, api_name)

    if output_schema is None:
        output_schema = api.output.spark_schema()

    return df.mapInArrow(process, output_schema)