from __future__ import annotations
import asyncio
import inspect
import logging
import math
import os
import pathlib
import sys
import typing as t
from functools import lru_cache
from functools import partial
import anyio.to_thread
import attrs
from simple_di import Provide
from simple_di import inject
from typing_extensions import Unpack
from bentoml import Runner
from bentoml._internal.bento.bento import Bento
from bentoml._internal.bento.build_config import BentoEnvSchema
from bentoml._internal.configuration.containers import BentoMLContainer
from bentoml._internal.context import ServiceContext
from bentoml._internal.models import Model as StoredModel
from bentoml._internal.utils import deprecated
from bentoml._internal.utils import dict_filter_none
from bentoml.exceptions import BentoMLConfigException
from bentoml.exceptions import BentoMLException
from ..images import Image
from ..method import APIMethod
from ..models import BentoModel
from ..models import HuggingFaceModel
from ..models import Model
from .config import ServiceConfig as Config
logger = logging.getLogger("bentoml.io")
T = t.TypeVar("T", bound=object)
if t.TYPE_CHECKING:
from bentoml._internal import external_typing as ext
from bentoml._internal.service.openapi.specification import OpenAPISpecification
from bentoml._internal.utils.circus import Server
from .dependency import Dependency
P = t.ParamSpec("P")
R = t.TypeVar("R")
class _ServiceDecorator(t.Protocol):
def __call__(self, inner: type[T]) -> Service[T]: ...
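# `with_config` below wraps Service methods so that the service's configuration
# overrides are injected into the global BentoML config (via `inject_config`)
# before the wrapped method runs; `serve_http` relies on this.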
def with_config(
func: t.Callable[t.Concatenate["Service[t.Any]", P], R],
) -> t.Callable[t.Concatenate["Service[t.Any]", P], R]:
def wrapper(self: Service[t.Any], *args: P.args, **kwargs: P.kwargs) -> R:
self.inject_config()
return func(self, *args, **kwargs)
return wrapper
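# A minimal sketch of the input `convert_envs` expects (assuming env entries
# carry "name"/"value" keys, as in bentofile `envs`; values are illustrative):
#
#     convert_envs([{"name": "HF_TOKEN", "value": "hf_..."}])
#     # -> [BentoEnvSchema(name="HF_TOKEN", value="hf_...")]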
def convert_envs(envs: t.List[t.Dict[str, t.Any]]) -> t.List[BentoEnvSchema]:
return [BentoEnvSchema(**env) for env in envs]
@attrs.define
class Service(t.Generic[T]):
"""A Bentoml service that can be served by BentoML server."""
config: Config
inner: type[T]
image: t.Optional[Image] = None
envs: t.List[BentoEnvSchema] = attrs.field(factory=list, converter=convert_envs)
bento: t.Optional[Bento] = attrs.field(init=False, default=None)
models: list[Model[t.Any]] = attrs.field(factory=list)
apis: dict[str, APIMethod[..., t.Any]] = attrs.field(factory=dict)
dependencies: dict[str, Dependency[t.Any]] = attrs.field(factory=dict, init=False)
mount_apps: list[tuple[ext.ASGIApp, str, str]] = attrs.field(
factory=list, init=False
)
middlewares: list[tuple[type[ext.AsgiMiddleware], dict[str, t.Any]]] = attrs.field(
factory=list, init=False
)
# service context
context: ServiceContext = attrs.field(init=False, factory=ServiceContext)
working_dir: str = attrs.field(init=False, factory=os.getcwd)
# import info
_caller_module: str = attrs.field(init=False)
_import_str: str | None = attrs.field(init=False, default=None)
def __attrs_post_init__(self) -> None:
from .dependency import Dependency
has_task = False
for field in dir(self.inner):
value = getattr(self.inner, field)
if isinstance(value, Dependency):
self.dependencies[field] = value
elif isinstance(value, StoredModel):
                logger.warning(
                    "Using `bentoml.models.get()` as a class attribute is not recommended because it requires"
                    f" the model to exist at import time. Use `{value._attr} = BentoModel({str(value.tag)!r})` instead."
                )
self.models.append(BentoModel(value.tag))
elif isinstance(value, Model):
self.models.append(t.cast(Model[t.Any], value))
elif isinstance(value, APIMethod):
if value.is_task:
has_task = True
self.apis[field] = t.cast("APIMethod[..., t.Any]", value)
        if has_task:
            # Task APIs are served through an external queue; default each
            # worker to one concurrent request unless explicitly configured.
            traffic = self.config.setdefault("traffic", {})
            traffic["external_queue"] = True
            traffic.setdefault("concurrency", 1)
pre_mount_apps = getattr(self.inner, "__bentoml_mounted_apps__", [])
if pre_mount_apps:
self.mount_apps.extend(pre_mount_apps)
delattr(self.inner, "__bentoml_mounted_apps__")
def __hash__(self):
return hash(self.name)
@_caller_module.default # type: ignore
def _get_caller_module(self) -> str:
if __name__ == "__main__":
return __name__
        # Walk up the call stack to the first frame whose module is not this
        # one; that is the module in which the service was defined.
        current_frame = inspect.currentframe()
        frame = current_frame
        while frame:
            this_name = frame.f_globals["__name__"]
            if this_name != __name__:
                return this_name
            frame = frame.f_back
        return __name__
def __repr__(self) -> str:
return f"<{self.__class__.__name__} name={self.name!r}>"
@lru_cache
def find_dependent_by_path(self, path: str) -> Service[t.Any]:
"""Find a service by path"""
attr_name, _, path = path.partition(".")
if attr_name not in self.dependencies:
if attr_name in self.all_services():
return self.all_services()[attr_name]
else:
raise BentoMLException(f"Service {attr_name} not found")
dependent = self.dependencies[attr_name]
if dependent.on is None:
raise BentoMLException(f"Service {attr_name} not found")
if path:
return dependent.on.find_dependent_by_path(path)
        return dependent.on
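    # A sketch with hypothetical services: if this service declares
    # `preprocess = bentoml.depends(Preprocess)` and `Preprocess` declares
    # `tokenizer = bentoml.depends(Tokenizer)`, then
    # `svc.find_dependent_by_path("preprocess.tokenizer")` resolves the nested
    # Tokenizer service one dot-separated segment at a time.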
def find_dependent_by_name(self, name: str) -> Service[t.Any]:
"""Find a service by name"""
try:
return self.all_services()[name]
except KeyError:
raise BentoMLException(f"Service {name} not found") from None
@property
def url(self) -> str | None:
"""Get the URL of the service, or None if the service is not served"""
dependency_map = BentoMLContainer.remote_runner_mapping.get()
url = dependency_map.get(self.name)
return url.replace("tcp://", "http://") if url else None
@lru_cache(maxsize=1)
def all_services(self) -> dict[str, Service[t.Any]]:
"""Get a map of the service and all recursive dependencies"""
services: dict[str, Service[t.Any]] = {self.name: self}
for dependency in self.dependencies.values():
if dependency.on is None:
continue
dependents = dependency.on.all_services()
conflict = next(
(
k
for k in dependents
if k in services and dependents[k] is not services[k]
),
None,
)
if conflict:
raise BentoMLConfigException(
f"Dependency conflict: {conflict} is already defined by {services[conflict].inner}"
)
services.update(dependents)
return services
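    # Note: the merge above raises BentoMLConfigException when two *different*
    # Service objects in the dependency graph share a name, because per-service
    # configuration (see `inject_config`) is keyed by service name.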
@property
def doc(self) -> str:
from bentoml._internal.bento.bento import get_default_svc_readme
if self.bento is not None:
return self.bento.doc
return get_default_svc_readme(self)
def schema(self) -> dict[str, t.Any]:
return dict_filter_none(
{
"name": self.name,
"type": "service",
"routes": [method.schema() for method in self.apis.values()],
"description": getattr(self.inner, "__doc__", None),
}
)
@property
def name(self) -> str:
        return self.config.get("name") or self.inner.__name__
@property
def import_string(self) -> str:
if self._import_str is None:
import_module = self._caller_module
if import_module == "__main__":
if hasattr(sys.modules["__main__"], "__file__"):
import_module = sys.modules["__main__"].__file__
assert isinstance(import_module, str)
try:
import_module_path = pathlib.Path(import_module).relative_to(
self.working_dir
)
                    except ValueError:
                        raise BentoMLException(
                            "Failed to get the service import origin: the __main__ module file is not under the working directory"
                        )
import_module = str(import_module_path.with_suffix("")).replace(
os.path.sep, "."
)
                else:
                    raise BentoMLException(
                        "Failed to get the service import origin: services defined interactively in a console or notebook are not supported"
                    )
            if self._caller_module not in sys.modules:
                raise BentoMLException(
                    "Failed to get the service import origin: the service object must be defined in an importable module"
                )
for name, value in vars(sys.modules[self._caller_module]).items():
if value is self:
self._import_str = f"{import_module}:{name}"
break
            else:
                raise BentoMLException(
                    "Failed to get the service import origin: the service object must be assigned to a module-level variable"
                )
return self._import_str
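    # For example, a service assigned to a module-level variable `svc` in
    # `src/service.py` (relative to the working directory) yields the import
    # string "src.service:svc".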
def to_asgi(self, is_main: bool = True) -> ext.ASGIApp:
from _bentoml_impl.server.app import ServiceAppFactory
self.inject_config()
factory = ServiceAppFactory(self, is_main=is_main)
return factory()
def mount_asgi_app(
self, app: ext.ASGIApp, path: str = "/", name: str | None = None
) -> None:
self.mount_apps.append((app, path, name)) # type: ignore
def mount_wsgi_app(
self, app: ext.WSGIApp, path: str = "/", name: str | None = None
) -> None:
from a2wsgi import WSGIMiddleware
self.mount_apps.append((WSGIMiddleware(app), path, name)) # type: ignore
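    # A minimal usage sketch (assumes an existing Flask/WSGI app `flask_app`):
    #
    #     svc.mount_wsgi_app(flask_app, path="/admin")
    #
    # The WSGI app is adapted to ASGI with a2wsgi's WSGIMiddleware before mounting.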
def add_asgi_middleware(
self, middleware_cls: type[ext.AsgiMiddleware], **options: t.Any
) -> None:
self.middlewares.append((middleware_cls, options))
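    # A usage sketch with Starlette's CORS middleware (Starlette ships with
    # BentoML); the keyword options are stored alongside the middleware class
    # and applied when the ASGI app is built:
    #
    #     from starlette.middleware.cors import CORSMiddleware
    #     svc.add_asgi_middleware(CORSMiddleware, allow_origins=["*"])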
    def gradio_app_startup_hook(self, max_concurrency: int) -> None:
gradio_apps = getattr(self.inner, "__bentoml_gradio_apps__", [])
if gradio_apps:
for gradio_app, path, _ in gradio_apps:
logger.info(f"Initializing gradio app at: {path or '/'}")
blocks = gradio_app.get_blocks()
blocks.queue(default_concurrency_limit=max_concurrency)
if hasattr(blocks, "startup_events"):
# gradio < 5.0
blocks.startup_events()
else:
# gradio >= 5.0
blocks.run_startup_events()
delattr(self.inner, "__bentoml_gradio_apps__")
def __call__(self) -> T:
try:
instance = self.inner()
instance.to_async = _AsyncWrapper(instance, self.apis.keys())
instance.to_sync = _SyncWrapper(instance, self.apis.keys())
return instance
except Exception:
            logger.exception("Error initializing service")
raise
@property
def openapi_spec(self) -> OpenAPISpecification:
from .openapi import generate_spec
return generate_spec(self)
def inject_config(self) -> None:
from bentoml._internal.configuration import load_config
from bentoml._internal.configuration.containers import BentoMLContainer
from bentoml._internal.configuration.containers import config_merger
# XXX: ensure at least one item to make `flatten_dict` work
override_defaults = {
"services": {
name: (svc.config or {"workers": 1})
for name, svc in self.all_services().items()
}
}
load_config(override_defaults=override_defaults, use_version=2)
main_config = BentoMLContainer.config.services[self.name].get()
api_server_keys = (
"traffic",
"metrics",
"logging",
"ssl",
"http",
"grpc",
"backlog",
"runner_probe",
"max_runner_connections",
)
        # Keys in `api_server_keys` are merged under the top-level "api_server"
        # section of the global config; all other service-level keys merge at
        # the config root.
        api_server_config = {
            k: main_config[k] for k in api_server_keys if main_config.get(k) is not None
        }
        rest_config = {
            k: main_config[k] for k in main_config if k not in api_server_keys
        }
existing = t.cast(t.Dict[str, t.Any], BentoMLContainer.config.get())
config_merger.merge(existing, {"api_server": api_server_config, **rest_config})
BentoMLContainer.config.set(existing) # type: ignore
@with_config
@inject
def serve_http(
self,
*,
working_dir: str | None = None,
port: int = Provide[BentoMLContainer.http.port],
host: str = Provide[BentoMLContainer.http.host],
backlog: int = Provide[BentoMLContainer.api_server_config.backlog],
timeout: int | None = None,
ssl_certfile: str | None = Provide[BentoMLContainer.ssl.certfile],
ssl_keyfile: str | None = Provide[BentoMLContainer.ssl.keyfile],
ssl_keyfile_password: str | None = Provide[
BentoMLContainer.ssl.keyfile_password
],
ssl_version: int | None = Provide[BentoMLContainer.ssl.version],
ssl_cert_reqs: int | None = Provide[BentoMLContainer.ssl.cert_reqs],
ssl_ca_certs: str | None = Provide[BentoMLContainer.ssl.ca_certs],
ssl_ciphers: str | None = Provide[BentoMLContainer.ssl.ciphers],
bentoml_home: str = Provide[BentoMLContainer.bentoml_home],
development_mode: bool = False,
reload: bool = False,
threaded: bool = False,
) -> Server:
from _bentoml_impl.server import serve_http
from bentoml._internal.log import configure_logging
configure_logging()
return serve_http(
self,
working_dir=working_dir,
host=host,
port=port,
backlog=backlog,
timeout=timeout,
ssl_certfile=ssl_certfile,
ssl_keyfile=ssl_keyfile,
ssl_keyfile_password=ssl_keyfile_password,
ssl_version=ssl_version,
ssl_cert_reqs=ssl_cert_reqs,
ssl_ca_certs=ssl_ca_certs,
ssl_ciphers=ssl_ciphers,
bentoml_home=bentoml_home,
development_mode=development_mode,
reload=reload,
threaded=threaded,
)
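    # A minimal usage sketch; unset arguments fall back to the values provided
    # by BentoMLContainer:
    #
    #     server = svc.serve_http(port=3000)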
def on_load_bento(self, bento: Bento) -> None:
service_info = next(svc for svc in bento.info.services if svc.name == self.name)
for model, info in zip(self.models, service_info.models):
            # Pin HuggingFace models to the model_id/revision recorded when the
            # bento was built; other model kinds are already stored by exact tag.
            if not isinstance(model, HuggingFaceModel):
                continue
            # Prefer the metadata value: it preserves the original casing of the
            # model id, while bento tag names are lower-cased.
            model_id = info.metadata.get("model_id")
if not model_id:
model_id = info.tag.name.replace("--", "/")
model.model_id = model_id
model.revision = info.tag.version
self.bento = bento
@t.overload
def service(inner: type[T], /) -> Service[T]: ...
@t.overload
def service(
inner: None = ...,
/,
*,
image: Image | None = None,
envs: list[dict[str, t.Any]] | None = None,
**kwargs: Unpack[Config],
) -> _ServiceDecorator: ...
def service(
inner: type[T] | None = None,
/,
*,
image: Image | None = None,
envs: list[dict[str, t.Any]] | None = None,
**kwargs: Unpack[Config],
) -> t.Any:
"""Mark a class as a BentoML service.
Example:
@service(traffic={"timeout": 60})
class InferenceService:
@api
def predict(self, input: str) -> str:
return input
"""
config = kwargs
def decorator(inner: type[T]) -> Service[T]:
if isinstance(inner, Service):
raise TypeError("service() decorator can only be applied once")
return Service(config=config, inner=inner, image=image, envs=envs or [])
return decorator(inner) if inner is not None else decorator
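# Note the two call styles matching the overloads above: `@service` applied
# directly to a class returns a `Service[T]`, while `@service(**config)` first
# returns a decorator to be applied to the class.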
@deprecated()
def runner_service(runner: Runner, **kwargs: Unpack[Config]) -> Service[t.Any]:
"""Make a service from a legacy Runner"""
if not isinstance(runner, Runner): # type: ignore
raise ValueError(f"Expect an instance of Runner, but got {type(runner)}")
class RunnerHandle(runner.runnable_class):
def __init__(self) -> None:
super().__init__(**runner.runnable_init_params)
RunnerHandle.__name__ = runner.name
apis: dict[str, APIMethod[..., t.Any]] = {}
assert runner.runnable_class.bentoml_runnable_methods__ is not None
for method in runner.runner_methods:
runnable_method = runner.runnable_class.bentoml_runnable_methods__[method.name]
api = APIMethod( # type: ignore
func=runnable_method.func,
batchable=runnable_method.config.batchable,
batch_dim=runnable_method.config.batch_dim,
max_batch_size=method.max_batch_size,
max_latency_ms=method.max_latency_ms,
)
apis[method.name] = api
config: Config = {}
resource_config = runner.resource_config or {}
if (
"nvidia.com/gpu" in runner.runnable_class.SUPPORTED_RESOURCES
and "nvidia.com/gpu" in resource_config
):
gpus: list[int] | str | int = resource_config["nvidia.com/gpu"]
if isinstance(gpus, str):
gpus = int(gpus)
        if runner.workers_per_resource > 1:
            # Multiple workers per GPU: repeat each GPU id once per worker.
            # `workers` must be a list here, since it is extended below.
            config["workers"] = []
            workers_per_resource = int(runner.workers_per_resource)
            if isinstance(gpus, int):
                gpus = list(range(gpus))
            for i in gpus:
                config["workers"].extend([{"gpus": i}] * workers_per_resource)
        else:
            # Fractional workers per GPU: each worker instead receives a group
            # of `resources_per_worker` GPUs.
            resources_per_worker = int(1 / runner.workers_per_resource)
if isinstance(gpus, int):
config["workers"] = [
{"gpus": resources_per_worker}
for _ in range(gpus // resources_per_worker)
]
else:
config["workers"] = [
{"gpus": gpus[i : i + resources_per_worker]}
for i in range(0, len(gpus), resources_per_worker)
]
elif "cpus" in resource_config:
config["workers"] = (
math.ceil(resource_config["cpus"]) * runner.workers_per_resource
)
config.update(kwargs)
return Service(
config=config,
inner=RunnerHandle,
models=[BentoModel(m.tag) for m in runner.models],
apis=apis,
)
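# A worked example with hypothetical numbers: 2 GPUs with workers_per_resource=2
# produce config["workers"] == [{"gpus": 0}, {"gpus": 0}, {"gpus": 1}, {"gpus": 1}],
# while workers_per_resource=0.5 instead groups 1 / 0.5 = 2 GPUs per worker.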
class _Wrapper:
    def __init__(self, wrapped: t.Any, apis: t.Iterable[str]) -> None:
        # Name-mangled to `_Wrapper__call`; backs `__call__` below when the
        # service exposes an API literally named "__call__".
        self.__call = None
for name in apis:
if name == "__call__":
self.__call = self._make_method(wrapped, name)
else:
setattr(self, name, self._make_method(wrapped, name))
def __call__(self, *args: t.Any, **kwargs: t.Any) -> t.Any:
if self.__call is None:
raise TypeError("This service is not callable.")
return self.__call(*args, **kwargs)
def _make_method(self, instance: t.Any, name: str) -> t.Any:
raise NotImplementedError
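# `_AsyncWrapper` and `_SyncWrapper` power the `to_async` / `to_sync` views set
# up in `Service.__call__` above: each API method is re-exposed in the opposite
# calling style, offloading sync functions to worker threads or driving
# coroutines on an event loop as needed.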
class _AsyncWrapper(_Wrapper):
def _make_method(self, instance: t.Any, name: str) -> t.Any:
        original_func = func = getattr(instance, name).local
        # Unwrap decorator layers (each exposes the wrapped callable as `.func`)
        # to inspect the user's original function.
        while hasattr(original_func, "func"):
            original_func = original_func.func
is_async_func = (
asyncio.iscoroutinefunction(original_func)
or (
callable(original_func)
and asyncio.iscoroutinefunction(original_func.__call__) # type: ignore
)
or inspect.isasyncgenfunction(original_func)
)
if is_async_func:
return func
if inspect.isgeneratorfunction(original_func):
async def wrapped_gen(
*args: t.Any, **kwargs: t.Any
) -> t.AsyncGenerator[t.Any, None]:
gen = func(*args, **kwargs)
next_fun = gen.__next__
while True:
try:
yield await anyio.to_thread.run_sync(next_fun)
except StopIteration:
break
                    except RuntimeError as e:
                        # PEP 479: a StopIteration escaping a generator is
                        # re-raised by the interpreter as RuntimeError.
                        if "raised StopIteration" in str(e):
                            break
                        raise
return wrapped_gen
else:
async def wrapped(*args: P.args, **kwargs: P.kwargs) -> t.Any:
return await anyio.to_thread.run_sync(partial(func, **kwargs), *args)
return wrapped
class _SyncWrapper(_Wrapper):
def _make_method(self, instance: t.Any, name: str) -> t.Any:
original_func = func = getattr(instance, name).local
while hasattr(original_func, "func"):
original_func = original_func.func
is_async_func = (
asyncio.iscoroutinefunction(original_func)
or (
callable(original_func)
and asyncio.iscoroutinefunction(original_func.__call__) # type: ignore
)
or inspect.isasyncgenfunction(original_func)
)
if not is_async_func:
return func
if inspect.isasyncgenfunction(original_func):
def wrapped_gen(
*args: t.Any, **kwargs: t.Any
) -> t.Generator[t.Any, None, None]:
            agen = func(*args, **kwargs)
            # Drive the async generator one item at a time on the event loop;
            # this cannot be used from inside an already-running loop.
            loop = asyncio.get_event_loop()
while True:
try:
yield loop.run_until_complete(agen.__anext__())
except StopAsyncIteration:
break
return wrapped_gen
else:
def wrapped(*args: P.args, **kwargs: P.kwargs) -> t.Any:
loop = asyncio.get_event_loop()
return loop.run_until_complete(func(*args, **kwargs))
return wrapped
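# A usage sketch for the wrappers, assuming `MyService` is a `@service`-decorated
# class with a `generate` API:
#
#     instance = MyService()                        # Service.__call__ builds the inner instance
#     out = instance.to_sync.generate("hi")         # drive an async API synchronously
#     out = await instance.to_async.generate("hi")  # await a sync API from async code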