# Source code for _bentoml_impl.client.http

from __future__ import annotations

import inspect
import io
import json
import logging
import mimetypes
import os
import pathlib
import tempfile
import time
import typing as t
from abc import abstractmethod
from functools import cached_property
from http import HTTPStatus
from urllib.parse import urljoin
from urllib.parse import urlparse

import attr
import httpx

from _bentoml_sdk import IODescriptor
from _bentoml_sdk.typing_utils import is_image_type
from bentoml import __version__
from bentoml._internal.utils.uri import is_http_url
from bentoml._internal.utils.uri import uri_to_path
from bentoml.exceptions import BentoMLException
from bentoml.exceptions import NotFound
from bentoml.exceptions import ServiceUnavailable

from ..serde import Payload
from ..tasks import ResultStatus
from .base import AbstractClient
from .base import ClientEndpoint
from .base import map_exception
from .task import AsyncTask
from .task import Task

if t.TYPE_CHECKING:
    from httpx._types import RequestFiles
    from PIL import Image

    from _bentoml_sdk import Service
    from bentoml._internal.external_typing import ASGIApp

    from ..serde import Serde

    T = t.TypeVar("T", bound="HTTPClient[t.Any]")
    A = t.TypeVar("A")
C = t.TypeVar("C", httpx.Client, httpx.AsyncClient)
AnyClient = t.TypeVar("AnyClient", httpx.Client, httpx.AsyncClient)
logger = logging.getLogger("bentoml.io")
MAX_RETRIES = 3


def to_async_iterable(iterable: t.Iterable[A]) -> t.AsyncIterable[A]:
    """Expose a synchronous iterable through the async-iteration protocol.

    The source iterable is consumed lazily: nothing is pulled from it until
    the returned async generator is iterated.
    """

    async def _relay() -> t.AsyncIterator[A]:
        # Re-yield each element unchanged, one at a time.
        for element in iterable:
            yield element

    async_view = _relay()
    return async_view


@attr.define
class HTTPClient(AbstractClient, t.Generic[C]):
    client_cls: t.ClassVar[type[httpx.Client] | type[httpx.AsyncClient]]

    url: str
    endpoints: dict[str, ClientEndpoint] = attr.field(factory=dict)
    media_type: str = "application/json"
    timeout: float = 30
    default_headers: dict[str, str] = attr.field(factory=dict)
    app: ASGIApp | None = None

    _opened_files: list[io.BufferedReader] = attr.field(init=False, factory=list)
    _temp_dir: tempfile.TemporaryDirectory[str] = attr.field(init=False)

    @staticmethod
    def _make_client(
        client_cls: type[AnyClient],
        url: str,
        headers: t.Mapping[str, str],
        timeout: float,
        app: ASGIApp | None = None,
    ) -> AnyClient:
        """Instantiate an httpx client of ``client_cls`` for the given target.

        The URL scheme selects the transport:

        * ``file://`` - talk over the unix domain socket at that path; the
          base URL is replaced by a fixed placeholder HTTP address.
        * ``tcp://`` - rewritten to a plain ``http://`` URL on the same netloc.
        * anything else - if ``app`` is given, requests are dispatched to the
          ASGI app directly (wrapped for WSGI when the client is synchronous).

        Args:
            client_cls: ``httpx.Client`` or ``httpx.AsyncClient``.
            url: Target URL of the service.
            headers: Default headers attached to every request.
            timeout: Timeout passed to the httpx client.
            app: Optional ASGI application to call instead of the network.

        Returns:
            A new client instance that follows redirects.
        """
        wants_sync = client_cls is httpx.Client
        parsed = urlparse(url)
        scheme = parsed.scheme
        transport = None

        if scheme == "file":
            socket_path = uri_to_path(url)
            transport_factory = (
                httpx.HTTPTransport if wants_sync else httpx.AsyncHTTPTransport
            )
            transport = transport_factory(uds=socket_path)
            url = "http://127.0.0.1:3000"
        elif scheme == "tcp":
            url = f"http://{parsed.netloc}"
        elif app is not None:
            if wants_sync:
                # Imported lazily: only needed for in-process sync dispatch.
                from a2wsgi import ASGIMiddleware

                transport = httpx.WSGITransport(app=ASGIMiddleware(app))
            else:
                transport = httpx.ASGITransport(app=app)

        return client_cls(
            base_url=url,
            transport=transport,  # type: ignore
            headers=headers,
            timeout=timeout,
            follow_redirects=True,
        )

    @_temp_dir.default  # type: ignore
    def default_temp_dir(self) -> tempfile.TemporaryDirectory[str]:
        """Build the default value of ``_temp_dir``: a fresh temporary directory."""
        scratch_dir = tempfile.TemporaryDirectory(prefix="bentoml-client-")
        return scratch_dir

    def __init__(
        self,
        url: str,
        *,
        media_type: str = "application/json",
        service: Service[t.Any] | None = None,
        server_ready_timeout: float | None = None,
        token: str | None = None,
        timeout: float = 30,
        app: ASGIApp | None = None,
    ) -> None:
        """Create a client instance from a URL.

        Args:
            url: The URL of the BentoML service.
            media_type: The media type to use for serialization. Defaults to
                "application/json".
            service: Optional service definition. When given, endpoints are
                built from ``service.apis`` and no schema is fetched from the
                server; extra identification headers are also attached.
            server_ready_timeout: How long to wait for the server to become
                ready. ``None`` waits using the client ``timeout``; a value
                ``<= 0`` skips the readiness wait. Ignored when ``app`` is set.
            token: Bearer token for the ``Authorization`` header. Falls back
                to the ``BENTO_CLOUD_API_KEY`` environment variable.
            timeout: Timeout passed to the underlying httpx clients.
            app: Optional ASGI application to send requests to in-process.

        Raises:
            ServiceUnavailable: If the readiness wait times out.
            BentoMLException: If ``/schema.json`` returns an error response.

        .. note::

            The client created with this method can only return primitive types without a model.
        """
        routes: dict[str, ClientEndpoint] = {}
        default_headers = {"User-Agent": f"BentoML HTTP Client/{__version__}"}
        if token is None:
            token = os.getenv("BENTO_CLOUD_API_KEY")
        if token:
            default_headers["Authorization"] = f"Bearer {token}"

        if service is not None:
            # Endpoints come straight from the service definition, including
            # the input/output specs that a fetched schema cannot provide.
            for name, method in service.apis.items():
                routes[name] = ClientEndpoint(
                    name=name,
                    route=method.route,
                    input=method.input_spec.model_json_schema(),
                    output=method.output_spec.model_json_schema(),
                    doc=method.doc,
                    input_spec=method.input_spec,
                    output_spec=method.output_spec,
                    stream_output=method.is_stream,
                    is_task=method.is_task,
                )

            from bentoml._internal.context import server_context

            default_headers.update(
                {
                    "Bento-Name": server_context.bento_name,
                    "Bento-Version": server_context.bento_version,
                    "Runner-Name": service.name,
                    "Yatai-Bento-Deployment-Name": server_context.yatai_bento_deployment_name,
                    "Yatai-Bento-Deployment-Namespace": server_context.yatai_bento_deployment_namespace,
                }
            )
        # attrs-generated initializer; must run before any attribute is used.
        self.__attrs_init__(  # type: ignore
            url=url,
            endpoints=routes,
            media_type=media_type,
            default_headers=default_headers,
            timeout=timeout,
            app=app,
        )
        # No readiness wait when calling an in-process app.
        if app is None and (server_ready_timeout is None or server_ready_timeout > 0):
            self.wait_until_server_ready(server_ready_timeout)
        if service is None:
            # schema_url is only used in the error message below.
            schema_url = urljoin(url, "/schema.json")

            # A short-lived sync client is used for the schema fetch, even
            # when this instance is an async client.
            with self._make_client(
                httpx.Client, url, default_headers, timeout, app=app
            ) as client:
                resp = client.get("/schema.json")

                if resp.is_error:
                    raise BentoMLException(f"Failed to fetch schema from {schema_url}")
                for route in resp.json()["routes"]:
                    self.endpoints[route["name"]] = ClientEndpoint(
                        name=route["name"],
                        route=route["route"],
                        input=route["input"],
                        output=route["output"],
                        doc=route.get("doc"),
                        stream_output=route["output"].get("is_stream", False),
                        is_task=route.get("is_task", False),
                    )
        super().__init__()

    @cached_property
    def client(self) -> C:
        """Underlying httpx client, created on first access and then reused."""
        factory = self._make_client
        return factory(
            self.client_cls,
            self.url,
            self.default_headers,
            self.timeout,
            self.app,
        )

    @cached_property
    def serde(self) -> Serde:
        """Serializer instance for ``self.media_type``, created on first access."""
        # Imported lazily, as in the rest of this module's serde lookups.
        from ..serde import ALL_SERDE

        serde_cls = ALL_SERDE[self.media_type]
        return serde_cls()

    def _build_request(
        self,
        endpoint: ClientEndpoint,
        args: t.Sequence[t.Any],
        kwargs: dict[str, t.Any],
        headers: t.Mapping[str, str],
    ) -> httpx.Request:
        """Build the POST request for calling ``endpoint`` with the given inputs.

        Two paths exist. When the endpoint carries an ``input_spec`` the
        inputs are validated into a model and encoded from it. Otherwise the
        request is built from the JSON schema in ``endpoint.input`` alone.

        Args:
            endpoint: The endpoint being called.
            args: Positional arguments supplied by the caller.
            kwargs: Keyword arguments supplied by the caller. Note that
                positional arguments are written into this dict in place.
            headers: Extra headers merged over the default ``Content-Type``.

        Returns:
            An ``httpx.Request`` built by the underlying client (not yet sent).

        Raises:
            TypeError: On duplicate, unknown or missing arguments, or when a
                root-input endpoint receives anything other than one
                positional argument.
            AssertionError: If no input spec is available and the media type
                is not JSON.
        """
        from opentelemetry import propagate

        from _bentoml_sdk.io_models import IORootModel

        headers = httpx.Headers({"Content-Type": self.media_type, **headers})
        # Let OpenTelemetry add its propagation headers to the outgoing request.
        propagate.inject(headers)
        if endpoint.input_spec is not None:
            model = endpoint.input_spec.from_inputs(*args, **kwargs)
            if (
                not isinstance(model, IORootModel)
                and model.multipart_fields
                and self.media_type == "application/json"
            ):
                # JSON cannot carry file fields, so switch to multipart.
                return self._build_multipart(endpoint, model, headers)
            elif isinstance(rendered := model.model_dump(), (str, bytes)):
                # The model dumps to raw text/bytes: send it as the body as-is.
                headers.update({"content-type": model.mime_type()})
                return self.client.build_request(
                    "POST", endpoint.route, content=rendered, headers=headers
                )
            else:
                payload = self.serde.serialize_model(model)
                headers.update(payload.headers)
                return self.client.build_request(
                    "POST",
                    endpoint.route,
                    headers=headers,
                    # The async client is given an async iterable body.
                    content=to_async_iterable(payload.data)
                    if self.client_cls is httpx.AsyncClient
                    else payload.data,
                )
        # From here on only the JSON schema is available.
        assert self.media_type == "application/json", (
            "Non-JSON request is not supported"
        )
        if endpoint.input.get("root_input", False):
            # Root input: the whole request body is one positional value.
            if len(args) > 1 or kwargs:
                raise TypeError("Expected one positional argument for root input")
            if not args:
                return self.client.build_request(
                    "POST", endpoint.route, headers=headers
                )
            value = args[0]
            passthrough = False
            content = None
            if "properties" in endpoint.input:
                # Object-shaped root: treat the value as the keyword mapping
                # and fall through to the keyword handling below.
                kwargs = value
                args = ()
                passthrough = True
            elif endpoint.input.get("type") == "file":
                file = self._get_file(value)
                if isinstance(file, str):
                    # _get_file returns a plain string for URL inputs.
                    content = file
                else:
                    file_io, content_type = file[1:]
                    # Stream the file in 4 KiB chunks until read() returns b"".
                    content = iter(lambda: file_io.read(4096), b"")
                    if content_type:
                        headers.update({"content-type": content_type})
            elif isinstance(value, (str, bytes)):
                content = value.encode("utf-8") if isinstance(value, str) else value
                headers.update({"content-type": "text/plain"})
            else:
                payload = self.serde.serialize(value, endpoint.input)
                headers.update(payload.headers)
                content = (
                    to_async_iterable(payload.data)
                    if self.client_cls is httpx.AsyncClient
                    else payload.data
                )
            if not passthrough:
                return self.client.build_request(
                    "POST", endpoint.route, content=content, headers=headers
                )

        # Map positional arguments onto property names in schema order.
        for name, value in zip(endpoint.input["properties"], args):
            if name in kwargs:
                raise TypeError(f"Duplicate argument {name}")
            kwargs[name] = value

        params = set(endpoint.input["properties"].keys())
        non_exist_args = set(kwargs.keys()) - set(params)
        if non_exist_args:
            raise TypeError(
                f"Arguments not found in endpoint {endpoint.name}: {non_exist_args}"
            )
        required = set(endpoint.input.get("required", []))
        missing_args = set(required) - set(kwargs.keys())
        if missing_args:
            raise TypeError(
                f"Missing required arguments in endpoint {endpoint.name}: {missing_args}"
            )
        # A property is a file field if it is a file or an array of files
        # (`and` binds tighter than `or` here).
        has_file = any(
            schema.get("type") == "file"
            or schema.get("type") == "array"
            and schema["items"].get("type") == "file"
            for schema in endpoint.input["properties"].values()
        )
        if has_file:
            return self._build_multipart(endpoint, kwargs, headers)
        payload = self.serde.serialize(kwargs, endpoint.input)
        headers.update(payload.headers)
        return self.client.build_request(
            "POST",
            endpoint.route,
            content=to_async_iterable(payload.data)
            if self.client_cls is httpx.AsyncClient
            else payload.data,
            headers=headers,
        )

    def wait_until_server_ready(self, timeout: int | None = None) -> None:
        """Block until the server answers ``GET /readyz`` with status 200.

        The endpoint is polled with a dedicated synchronous client. Timeouts
        and connection errors are treated as "not ready yet" and retried.

        Args:
            timeout: Maximum number of seconds to wait. Defaults to the
                client's own ``timeout`` when ``None``.

        Raises:
            ServiceUnavailable: If the server is not ready within ``timeout``.
        """
        if timeout is None:
            timeout = self.timeout
        # Pause between attempts. Without it a refused connection or a non-200
        # reply returns immediately and the loop would spin at full speed.
        poll_interval = 0.1
        with self._make_client(
            httpx.Client, self.url, self.default_headers, timeout
        ) as client:
            deadline = time.monotonic() + timeout
            while time.monotonic() < deadline:
                try:
                    resp = client.get("/readyz")
                    if resp.status_code == 200:
                        return
                except (httpx.TimeoutException, httpx.ConnectError):
                    pass
                remaining = deadline - time.monotonic()
                if remaining <= 0:
                    break
                # Never sleep past the deadline.
                time.sleep(min(poll_interval, remaining))
        raise ServiceUnavailable(f"Server is not ready after {timeout} seconds")

    def _get_file(self, value: t.Any) -> str | tuple[str, t.IO[bytes], str | None]:
        """Normalize a file input into a form usable in a request.

        Accepted inputs:

        * an HTTP URL string - returned unchanged;
        * any other string or a ``pathlib`` path - opened for binary reading;
          the handle is recorded in ``self._opened_files`` so the caller can
          close it once the request is done;
        * an image object (as recognised by ``is_image_type``) - its
          underlying file pointer is reused;
        * any object with a callable ``read`` attribute - used as-is.

        Args:
            value: The user-supplied file value.

        Returns:
            Either the URL string, or a ``(filename, fileobj, content_type)``
            tuple where ``content_type`` may be ``None`` if it cannot be
            guessed from the file name.

        Raises:
            TypeError: If ``value`` is none of the accepted kinds.
        """
        if isinstance(value, str) and not is_http_url(value):
            # A non-URL string is treated as a local filesystem path.
            value = pathlib.Path(value)
        if is_image_type(type(value)):
            fp = getattr(value, "_fp", value.fp)
            fname = getattr(fp, "name", None)
            fmt = value.format.lower()
            return (
                pathlib.Path(fname).name if fname else f"upload-image.{fmt}",
                fp,
                f"image/{fmt}",
            )
        if isinstance(value, pathlib.PurePath):
            file = open(value, "rb")
            self._opened_files.append(file)
            return (value.name, file, mimetypes.guess_type(value)[0])
        if isinstance(value, str):
            return value

        # typing.BinaryIO is not a runtime base class of real file objects
        # (io.BytesIO, files from open(..., "rb")), so an isinstance check
        # against it rejects every genuine stream. Check for read() instead.
        if not callable(getattr(value, "read", None)):
            raise TypeError(
                f"Unsupported file input of type {type(value).__name__!r}: expected "
                "a path, an HTTP URL, an image or a binary file-like object"
            )
        raw_name = getattr(value, "name", None)
        if not isinstance(raw_name, (str, os.PathLike)):
            # e.g. streams without a name, or opened from an integer descriptor.
            raw_name = "upload-file"
        filename = pathlib.Path(raw_name).name
        content_type = mimetypes.guess_type(filename)[0]
        return (filename, value, content_type)

    def _build_multipart(
        self,
        endpoint: ClientEndpoint,
        model: IODescriptor | dict[str, t.Any],
        headers: httpx.Headers,
    ) -> httpx.Request:
        """Build a multipart POST request for an endpoint that takes files.

        Non-file fields are JSON-encoded into form fields. File fields are
        resolved through ``_get_file``: real files are attached as uploads,
        while URL strings are sent as form values under the same field name.

        Args:
            endpoint: The endpoint being called.
            model: Either a validated input model or a plain mapping of
                argument names to values.
            headers: Request headers. Any ``content-type`` entry is removed
                in place before the request is built.

        Returns:
            The ``httpx.Request`` built by the underlying client.
        """

        def is_file_field(k: str) -> bool:
            # With a model, trust its own list of multipart fields; with a
            # plain dict, fall back to the endpoint's JSON schema.
            if isinstance(model, IODescriptor):
                return k in model.multipart_fields
            if (f := endpoint.input["properties"].get(k, {})).get("type") == "file":
                return True
            if f.get("type") == "array" and f["items"].get("type") == "file":
                return True
            return False

        if isinstance(model, dict):
            fields = model
        else:
            fields = {k: getattr(model, k) for k in model.model_fields}
        data: dict[str, t.Any] = {}
        files: RequestFiles = []

        for name, value in fields.items():
            if not is_file_field(name):
                data[name] = json.dumps(value)
                continue
            if not isinstance(value, (list, tuple)):
                value = [value]

            for v in value:
                file = self._get_file(v)
                if isinstance(file, str):
                    # Collect every URL for this field. A plain assignment
                    # would keep only the last one when several are passed.
                    data.setdefault(name, []).append(file)
                else:
                    files.append((name, file))
        # Drop any preset content type before building the multipart request.
        headers.pop("content-type", None)
        return self.client.build_request(
            "POST", endpoint.route, data=data, files=files, headers=headers
        )

    def _deserialize_output(self, payload: Payload, endpoint: ClientEndpoint) -> t.Any:
        """Turn a response payload into the Python value for ``endpoint``.

        String outputs are taken from the first chunk of the payload: returned
        as raw bytes for the ``binary`` format and UTF-8 decoded otherwise.
        Other outputs go through the serde, using the endpoint's output spec
        when one is available and its JSON schema when not.
        """
        from _bentoml_sdk.io_models import IORootModel

        chunks = iter(payload.data)
        output_schema = endpoint.output

        if output_schema.get("type") == "string":
            first_chunk = bytes(next(chunks))
            wants_bytes = output_schema.get("format") == "binary"
            return first_chunk if wants_bytes else first_chunk.decode("utf-8")

        if endpoint.output_spec is None:
            return self.serde.deserialize(payload, output_schema)

        result = self.serde.deserialize_model(payload, endpoint.output_spec)
        # Root models are unwrapped so callers receive the bare value.
        if isinstance(result, IORootModel):
            return result.root  # type: ignore
        return result

    def call(self, __name: str, /, *args: t.Any, **kwargs: t.Any) -> t.Any:
        """Invoke the endpoint registered under the given name.

        Streaming endpoints are dispatched to ``_get_stream``, all others to
        ``_call``; whatever the chosen handler returns is returned as-is.

        Raises:
            NotFound: If no endpoint with that name is registered.
        """
        try:
            target = self.endpoints[__name]
        except KeyError:
            raise NotFound(f"Endpoint {__name} not found") from None
        dispatch = self._get_stream if target.stream_output else self._call
        return dispatch(target, args, kwargs)

    # Subclass hook: perform the request for ``endpoint`` with the given
    # positional/keyword inputs and return the result. ``call`` dispatches
    # here for endpoints whose ``stream_output`` is false. ``headers`` are
    # optional extra request headers.
    @abstractmethod
    def _call(
        self,
        endpoint: ClientEndpoint,
        args: t.Sequence[t.Any],
        kwargs: dict[str, t.Any],
        *,
        headers: t.Mapping[str, str] | None = None,
    ) -> t.Any: ...

    # Subclass hook: call a streaming endpoint and return an iterator over
    # its output. ``call`` dispatches here when ``endpoint.stream_output``
    # is true.
    @abstractmethod
    def _get_stream(
        self, endpoint: ClientEndpoint, args: t.Any, kwargs: t.Any
    ) -> t.Any: ...


class SyncHTTPClient(HTTPClient[httpx.Client]):
    """A synchronous client for BentoML service.

    .. note:: Inner usage ONLY
    """

    client_cls = httpx.Client

    def __enter__(self: T) -> T:
        return self

    def __exit__(self, exc_type: t.Any, exc: t.Any, tb: t.Any) -> None:
        return self.close()

    def is_ready(self, timeout: int | None = None) -> bool:
        """Return whether ``GET /readyz`` answers with status 200.

        Args:
            timeout: Per-request timeout; a falsy value uses the client default.

        Returns:
            ``True`` on a 200 response, ``False`` otherwise or on timeout.
        """
        try:
            resp = self.client.get(
                "/readyz", timeout=timeout or httpx.USE_CLIENT_DEFAULT
            )
            return resp.status_code == 200
        except httpx.TimeoutException:
            logger.warning("Timed out waiting for runner to be ready")
            return False

    def close(self) -> None:
        """Close the underlying httpx client, if it was ever created."""
        # ``client`` is a cached_property: only close it when already cached,
        # so closing never creates a client as a side effect.
        if "client" in vars(self):
            self.client.close()

    def _get_stream(
        self, endpoint: ClientEndpoint, args: t.Any, kwargs: t.Any
    ) -> t.Generator[t.Any, None, None]:
        """Call a streaming endpoint and yield its outputs one by one."""
        resp = self._call(endpoint, args, kwargs)
        for data in resp:
            yield data

    def request(self, method: str, url: str, **kwargs: t.Any) -> httpx.Response:
        """Send an arbitrary request through the underlying httpx client."""
        return self.client.request(method, url, **kwargs)

    def _submit(
        self, __endpoint: ClientEndpoint, /, *args: t.Any, **kwargs: t.Any
    ) -> Task:
        """Submit a task to ``<route>/submit`` and return a handle to it.

        Raises:
            BentoMLException: If the server answers with an error status.
        """
        try:
            req = self._build_request(__endpoint, args, kwargs, {})
            req.url = req.url.copy_with(path=f"{__endpoint.route}/submit")
            resp = self.client.send(req)
            if resp.is_error:
                resp.read()
                raise BentoMLException(
                    f"Error making request: {resp.status_code}: {resp.text}",
                    error_code=HTTPStatus(resp.status_code),
                )
            data = resp.json()
            return Task(data["task_id"], __endpoint, self)
        finally:
            # Close any files opened while building the request.
            for f in self._opened_files:
                f.close()
            self._opened_files.clear()

    def _get_task_result(self, __endpoint: ClientEndpoint, /, task_id: str) -> t.Any:
        """Fetch and decode the result of a task from ``<route>/get``."""
        resp = self.request(
            "GET", f"{__endpoint.route}/get", params={"task_id": task_id}
        )
        if resp.is_error:
            resp.read()
            raise map_exception(resp)
        if (
            __endpoint.output.get("type") == "file"
            and self.media_type == "application/json"
        ):
            return self._parse_file_response(__endpoint, resp)
        else:
            return self._parse_response(__endpoint, resp)

    def _get_task_status(
        self, __endpoint: ClientEndpoint, /, task_id: str
    ) -> ResultStatus:
        """Query ``<route>/status`` for the status of a task."""
        resp = self.client.request(
            "GET", f"{__endpoint.route}/status", params={"task_id": task_id}
        )
        if resp.is_error:
            resp.read()
            raise map_exception(resp)
        data = resp.json()
        return ResultStatus(data["status"])

    def _cancel_task(self, __endpoint: ClientEndpoint, /, task_id: str) -> None:
        """Ask the server to cancel a task via ``<route>/cancel``."""
        resp = self.request(
            "PUT", f"{__endpoint.route}/cancel", params={"task_id": task_id}
        )
        if resp.is_error:
            resp.read()
            raise map_exception(resp)

    def _retry_task(self, __endpoint: ClientEndpoint, /, task_id: str) -> Task:
        """Retry a task via ``<route>/retry`` and return the new task handle."""
        resp = self.request(
            "POST", f"{__endpoint.route}/retry", params={"task_id": task_id}
        )
        if resp.is_error:
            resp.read()
            raise map_exception(resp)
        data = resp.json()
        return Task(data["task_id"], __endpoint, self)

    def _call(
        self,
        endpoint: ClientEndpoint,
        args: t.Sequence[t.Any],
        kwargs: dict[str, t.Any],
        *,
        headers: t.Mapping[str, str] | None = None,
    ) -> t.Any:
        """Send the request for ``endpoint`` and decode the response.

        Returns a generator for streaming endpoints, a path or image for file
        outputs, and the deserialized value otherwise. Error responses are
        converted with ``map_exception`` and raised.
        """
        try:
            req = self._build_request(endpoint, args, kwargs, headers or {})
            resp = self.client.send(req, stream=endpoint.stream_output)
            if resp.is_error:
                resp.read()
                raise map_exception(resp)
            if endpoint.stream_output:
                return self._parse_stream_response(endpoint, resp)
            elif endpoint.output.get("type") == "file":
                # file responses are always raw binaries whatever the serde is
                return self._parse_file_response(endpoint, resp)
            else:
                return self._parse_response(endpoint, resp)
        finally:
            # Close any files opened while building the request.
            for f in self._opened_files:
                f.close()
            self._opened_files.clear()

    def _parse_response(self, endpoint: ClientEndpoint, resp: httpx.Response) -> t.Any:
        """Read the whole response body and deserialize it."""
        payload = Payload((resp.read(),), resp.headers)
        return self._deserialize_output(payload, endpoint)

    def _parse_stream_response(
        self, endpoint: ClientEndpoint, resp: httpx.Response
    ) -> t.Generator[t.Any, None, None]:
        """Deserialize each received chunk; close the response when done."""
        try:
            for data in resp.iter_bytes():
                yield self._deserialize_output(Payload((data,), resp.headers), endpoint)
        finally:
            resp.close()

    def _parse_file_response(
        self, endpoint: ClientEndpoint, resp: httpx.Response
    ) -> pathlib.Path | Image.Image:
        """Turn a file response into an image or a file in the temp directory.

        When the output schema has a truthy ``pil`` entry the body is opened
        as an image. Otherwise it is written to a new file inside
        ``self._temp_dir`` whose name ends with the file name from the
        ``Content-Disposition`` header, when one is present.
        """
        from PIL import Image
        from python_multipart.multipart import parse_options_header

        content_disposition = resp.headers.get("content-disposition")
        content_type = resp.headers.get("content-type", "")
        filename: str | None = None
        if endpoint.output.get("pil"):
            image_formats = (
                [content_type[6:]] if content_type.startswith("image/") else None
            )
            return Image.open(io.BytesIO(resp.read()), formats=image_formats)
        if content_disposition:
            _, options = parse_options_header(content_disposition)
            if b"filename" in options:
                filename = str(
                    options[b"filename"],
                    resp.charset_encoding or "utf-8",
                    errors="ignore",
                )
                # The name comes from the server: keep only its final path
                # component so it cannot point outside the temp directory.
                filename = pathlib.Path(filename).name

        with tempfile.NamedTemporaryFile(
            "wb", suffix=filename, dir=self._temp_dir.name, delete=False
        ) as f:
            f.write(resp.read())
        return pathlib.Path(f.name)


class AsyncHTTPClient(HTTPClient[httpx.AsyncClient]):
    """An asynchronous client for BentoML service.

    .. note:: Inner usage ONLY
    """

    client_cls = httpx.AsyncClient

    async def is_ready(self, timeout: int | None = None) -> bool:
        """Return whether ``GET /readyz`` answers with status 200.

        Args:
            timeout: Per-request timeout; a falsy value uses the client default.

        Returns:
            ``True`` on a 200 response, ``False`` otherwise or on timeout.
        """
        try:
            resp = await self.client.get(
                "/readyz", timeout=timeout or httpx.USE_CLIENT_DEFAULT
            )
            return resp.status_code == 200
        except httpx.TimeoutException:
            logger.warning("Timed out waiting for runner to be ready")
            return False

    async def _get_stream(
        self, endpoint: ClientEndpoint, args: t.Any, kwargs: t.Any
    ) -> t.AsyncGenerator[t.Any, None]:
        """Call a streaming endpoint and yield its outputs one by one."""
        resp = await self._call(endpoint, args, kwargs)
        assert inspect.isasyncgen(resp)
        async for data in resp:
            yield data

    async def __aenter__(self: T) -> T:
        return self

    async def __aexit__(self, *args: t.Any) -> None:
        return await self.close()

    async def request(self, method: str, url: str, **kwargs: t.Any) -> httpx.Response:
        """Send an arbitrary request through the underlying httpx client."""
        return await self.client.request(method, url, **kwargs)

    async def _submit(
        self, __endpoint: ClientEndpoint, /, *args: t.Any, **kwargs: t.Any
    ) -> AsyncTask:
        """Submit a task to ``<route>/submit`` and return a handle to it.

        Raises:
            BentoMLException: If the server answers with an error status.
        """
        try:
            req = self._build_request(__endpoint, args, kwargs, {})
            req.url = req.url.copy_with(path=f"{__endpoint.route}/submit")
            resp = await self.client.send(req)
            if resp.is_error:
                # Use the async read, like every other method of this class.
                await resp.aread()
                raise BentoMLException(
                    f"Error making request: {resp.status_code}: {resp.text}",
                    error_code=HTTPStatus(resp.status_code),
                )
            data = resp.json()
            return AsyncTask(data["task_id"], __endpoint, self)
        finally:
            # Close any files opened while building the request.
            for f in self._opened_files:
                f.close()
            self._opened_files.clear()

    async def _get_task_status(
        self, __endpoint: ClientEndpoint, /, task_id: str
    ) -> ResultStatus:
        """Query ``<route>/status`` for the status of a task."""
        resp = await self.client.request(
            "GET", f"{__endpoint.route}/status", params={"task_id": task_id}
        )
        if resp.is_error:
            await resp.aread()
            raise map_exception(resp)
        data = resp.json()
        return ResultStatus(data["status"])

    async def _cancel_task(self, __endpoint: ClientEndpoint, /, task_id: str) -> None:
        """Ask the server to cancel a task via ``<route>/cancel``."""
        resp = await self.request(
            "PUT", f"{__endpoint.route}/cancel", params={"task_id": task_id}
        )
        if resp.is_error:
            await resp.aread()
            raise map_exception(resp)

    async def _retry_task(
        self, __endpoint: ClientEndpoint, /, task_id: str
    ) -> AsyncTask:
        """Retry a task via ``<route>/retry`` and return the new task handle."""
        resp = await self.request(
            "POST", f"{__endpoint.route}/retry", params={"task_id": task_id}
        )
        if resp.is_error:
            await resp.aread()
            raise map_exception(resp)
        data = resp.json()
        return AsyncTask(data["task_id"], __endpoint, self)

    async def _get_task_result(
        self, __endpoint: ClientEndpoint, /, task_id: str
    ) -> t.Any:
        """Fetch and decode the result of a task from ``<route>/get``."""
        resp = await self.request(
            "GET", f"{__endpoint.route}/get", params={"task_id": task_id}
        )
        if resp.is_error:
            await resp.aread()
            raise map_exception(resp)
        if (
            __endpoint.output.get("type") == "file"
            and self.media_type == "application/json"
        ):
            return await self._parse_file_response(__endpoint, resp)
        else:
            return await self._parse_response(__endpoint, resp)

    async def _call(
        self,
        endpoint: ClientEndpoint,
        args: t.Sequence[t.Any],
        kwargs: dict[str, t.Any],
        *,
        headers: t.Mapping[str, str] | None = None,
    ) -> t.Any:
        """Send the request for ``endpoint`` and decode the response.

        Returns an async generator for streaming endpoints, a path or image
        for file outputs, and the deserialized value otherwise. Error
        responses are converted with ``map_exception`` and raised.
        """
        try:
            req = self._build_request(endpoint, args, kwargs, headers or {})
            resp = await self.client.send(req, stream=endpoint.stream_output)
            if resp.is_error:
                await resp.aread()
                raise map_exception(resp)
            if endpoint.stream_output:
                return self._parse_stream_response(endpoint, resp)
            elif endpoint.output.get("type") == "file":
                # file responses are always raw binaries whatever the serde is
                return await self._parse_file_response(endpoint, resp)
            else:
                return await self._parse_response(endpoint, resp)
        finally:
            # Close any files opened while building the request.
            for f in self._opened_files:
                f.close()
            self._opened_files.clear()

    async def _parse_response(
        self, endpoint: ClientEndpoint, resp: httpx.Response
    ) -> t.Any:
        """Read the whole response body and deserialize it."""
        data = await resp.aread()
        return self._deserialize_output(Payload((data,), resp.headers), endpoint)

    async def _parse_stream_response(
        self, endpoint: ClientEndpoint, resp: httpx.Response
    ) -> t.AsyncGenerator[t.Any, None]:
        """Deserialize each received chunk; close the response when done."""
        try:
            async for data in resp.aiter_bytes():
                yield self._deserialize_output(Payload((data,), resp.headers), endpoint)
        finally:
            await resp.aclose()

    async def _parse_file_response(
        self, endpoint: ClientEndpoint, resp: httpx.Response
    ) -> pathlib.Path | Image.Image:
        """Turn a file response into an image or a file in the temp directory.

        When the output schema has a truthy ``pil`` entry the body is opened
        as an image. Otherwise it is written to a new file inside
        ``self._temp_dir`` whose name ends with the file name from the
        ``Content-Disposition`` header, when one is present.
        """
        from PIL import Image
        from python_multipart.multipart import parse_options_header

        content_disposition = resp.headers.get("content-disposition")
        content_type = resp.headers.get("content-type", "")
        filename: str | None = None
        if endpoint.output.get("pil"):
            image_formats = (
                [content_type[6:]] if content_type.startswith("image/") else None
            )
            return Image.open(io.BytesIO(await resp.aread()), formats=image_formats)
        if content_disposition:
            _, options = parse_options_header(content_disposition)
            if b"filename" in options:
                filename = str(
                    options[b"filename"],
                    resp.charset_encoding or "utf-8",
                    errors="ignore",
                )
                # The name comes from the server: keep only its final path
                # component so it cannot point outside the temp directory.
                filename = pathlib.Path(filename).name

        with tempfile.NamedTemporaryFile(
            "wb", suffix=filename, dir=self._temp_dir.name, delete=False
        ) as f:
            f.write(await resp.aread())
        return pathlib.Path(f.name)

    async def close(self) -> None:
        """Close the underlying httpx client, if it was ever created."""
        # ``client`` is a cached_property: only close it when already cached,
        # so closing never creates a client as a side effect.
        if "client" in vars(self):
            await self.client.aclose()