from __future__ import annotations

import logging
import os.path
import tempfile
import typing as t
from typing import TYPE_CHECKING

from bentoml._internal.utils import reserve_free_port

from ...bentos import import_bento
from ...bentos import serve
from ...client import Client
from ...client import HTTPClient
from ...exceptions import BentoMLException
from ...exceptions import MissingDependencyException
from ..bento import Bento
from ..service.loader import load_bento
from ..tag import Tag

    from pyspark.files import SparkFiles
    from pyspark.sql.types import StructType
except ImportError:  # pragma: no cover (trivial error)
    raise MissingDependencyException(
        "'pyspark' is required in order to use module 'bentoml.spark', install pyspark with 'pip install pyspark'. For more information, refer to"

    import pyspark.sql.dataframe
    import pyspark.sql.session

    RecordBatch: t.TypeAlias = t.Any  # pyarrow doesn't have type annotations

logger = logging.getLogger(__name__)

def _distribute_bento(spark: pyspark.sql.session.SparkSession, bento: Bento) -> str:
    """
    Export ``bento`` to a fresh temporary directory and register the archive with Spark.

    Args:
        spark: The session whose ``sparkContext`` is used to distribute the exported file.
        bento: The Bento to export.

    Returns:
        The base file name of the exported archive, i.e. the name under which workers
        can look it up with ``SparkFiles.get``.
    """
    # The temporary directory is deliberately not removed: Spark presumably still needs
    # to read the file when executors fetch it, after this function has returned.
    temp_dir = tempfile.mkdtemp()
    export_path = bento.export(temp_dir)
    # Register the archive so every node can download it; without this the SparkFiles
    # fallback used on the worker side has nothing to find.
    spark.sparkContext.addFile(export_path)
    return os.path.basename(export_path)

def _load_bento_spark(bento_tag: Tag):
    """
    Load Bento from local bento store or the SparkFiles directory.

    First tries the local Bento store. If that fails, looks for the exported archive
    ``<name>-<version>.bento`` in the SparkFiles directory. If found, it is imported into
    the local store and loaded from there.

    Args:
        bento_tag: Tag of the Bento to load.

    Raises:
        Whatever ``load_bento`` raised originally, if no exported archive exists in the
        SparkFiles directory.
    """
    try:
        return load_bento(bento_tag)
    except Exception:
        # Use the default Bento export file name. This relies on the implementation
        # of _distribute_bento to use default Bento export file name.
        bento_path = SparkFiles.get(f"{bento_tag.name}-{bento_tag.version}.bento")
        if not os.path.isfile(bento_path):
            # No fallback available: re-raise the original load failure.
            raise

        import_bento(bento_path)
        return load_bento(bento_tag)

def _get_process(
    bento_tag: Tag, api_name: str
) -> t.Callable[[t.Iterable[RecordBatch]], t.Generator[RecordBatch, None, None]]:
    """
    Build the per-partition function handed to ``DataFrame.mapInArrow``.

    The returned generator function runs on a Spark worker. It does the following:

    - loads the Bento (via ``_load_bento_spark``);
    - starts a BentoML server for it on localhost;
    - for every incoming Arrow record batch, converts it with the API's input
      descriptor, calls ``api_name`` through an ``HTTPClient``, and yields the result
      converted back to Arrow.

    Args:
        bento_tag: Tag of the Bento to serve on the worker.
        api_name: Name of the inference API to call for each batch.

    Returns:
        A generator function mapping an iterable of record batches to record batches.
    """

    def process(
        iterator: t.Iterable[RecordBatch],
    ) -> t.Generator[RecordBatch, None, None]:
        svc = _load_bento_spark(bento_tag)

        # Explicit checks rather than `assert`, so validation survives `python -O`.
        if api_name not in svc.apis:
            raise BentoMLException(
                "An error occurred transferring the Bento to the Spark worker."
            )
        inference_api = svc.apis[api_name]
        if inference_api.func is None:
            raise BentoMLException("Inference API function not defined")

        # reserve_free_port() is used only to pick a port number. The empty body means
        # the reservation ends before serve() binds it.
        # NOTE(review): another process could take the port in between; confirm this
        # race is acceptable.
        with reserve_free_port() as port:
            pass

        # start bento server
        server = serve(bento_tag, port=port)
        Client.wait_until_server_ready("localhost", server.port, 30)
        client = HTTPClient(svc, f"http://localhost:{server.port}")

        for batch in iterator:
            func_input = inference_api.input.from_arrow(batch)
            func_output = client.call(api_name, func_input)
            yield inference_api.output.to_arrow(func_output)
        # NOTE(review): the server started above is never shut down here; confirm
        # whether the handle returned by serve() should be stopped once the partition
        # is exhausted.

    return process

def run_in_spark(
    bento: Bento,
    df: pyspark.sql.dataframe.DataFrame,
    spark: pyspark.sql.session.SparkSession,
    api_name: str | None = None,
    output_schema: StructType | None = None,
) -> pyspark.sql.dataframe.DataFrame:
    """
    Run BentoService inference API in Spark.

    The API to run must accept batches as input and return batches as output.

    Args:
        bento:
            The bento containing the inference API to run.
        df:
            The input DataFrame to run the inference API on.
        spark:
            The spark session to use to run the inference API.
        api_name:
            The name of the inference API to run. If not provided, there must be only one
            API contained in the bento; that API will be run.
        output_schema:
            The Spark schema of the output DataFrame. If not provided, BentoML will attempt
            to infer the schema from the output descriptor of the inference API.

    Returns:
        The result of the inference API run on the input ``df``.

    Raises:
        BentoMLException: If the bento defines no APIs, or ``api_name`` is omitted and the
            bento defines more than one API, or ``api_name`` is not an API of the bento.

    Examples
    --------

    .. code-block:: python

        >>> import bentoml
        >>> from pyspark.sql import SparkSession
        >>> from pyspark.sql.types import StructType, StructField, StringType, IntegerType

        >>> spark = SparkSession.builder.getOrCreate()
        >>> schema = StructType([
        ...     StructField("name", StringType(), True),
        ...     StructField("age", IntegerType(), True),
        ... ])
        >>> df = spark.createDataFrame([("John", 30), ("Mike", 25), ("Sally", 40)], schema)

        >>> bento = bentoml.get("my_service:latest")
        >>> results = bentoml.batch.run_in_spark(bento, df, spark)
        >>> results.show()  # columns depend on the API's output descriptor / output_schema
    """
    svc = load_bento(bento)

    if api_name is None:
        if not svc.apis:
            # Previously this case fell through to the misleading "multiple APIs" error.
            raise BentoMLException(f'Bento "{bento.tag}" does not define any APIs to run.')
        if len(svc.apis) > 1:
            raise BentoMLException(
                f'Bento "{bento.tag}" has multiple APIs ({svc.apis.keys()}), specify which API should be run, e.g.: bentoml.batch.run_in_spark("my_service:latest", df, spark, api_name="predict")'
            )
        api_name = next(iter(svc.apis))
    elif api_name not in svc.apis:
        raise BentoMLException(
            f"API name '{api_name}' not found in Bento '{bento.tag}', available APIs are {svc.apis.keys()}"
        )

    api = svc.apis[api_name]

    # Export the Bento so workers that do not have it in their local store can load it.
    _distribute_bento(spark, bento)

    process = _get_process(bento.tag, api_name)

    if output_schema is None:
        output_schema = api.output.spark_schema()

    return df.mapInArrow(process, output_schema)