Source code for bentoml._internal.models.model

from __future__ import annotations

import importlib
import io
import logging
import os
import typing as t
from datetime import datetime
from datetime import timezone
from sys import version_info as pyver
from types import ModuleType
from typing import TYPE_CHECKING
from typing import overload

import attr
import cloudpickle  # type: ignore (no cloudpickle types)
import fs
import fs.errors
import fs.mirror
import yaml
from cattr.gen import make_dict_structure_fn
from cattr.gen import make_dict_unstructure_fn
from cattr.gen import override
from fs.base import FS
from simple_di import Provide
from simple_di import inject

from ...exceptions import BentoMLException
from ...exceptions import NotFound
from ..configuration import BENTOML_VERSION
from ..configuration.containers import BentoMLContainer
from import Store
from import StoreItem
from ..tag import Tag
from ..types import MetadataDict
from ..types import ModelSignatureDict
from ..utils import bentoml_cattr
from ..utils import label_validator
from ..utils import metadata_validator
from ..utils import normalize_labels_value

    from ..runner import Runnable
    from ..runner import Runner
    from ..runner.strategy import Strategy
    from ..types import PathType

# Generic type variable for use in annotations.
T = t.TypeVar("T")

# Logger named after this module.
logger = logging.getLogger(__name__)

# "major.minor.micro" of the interpreter currently running.
PYTHON_VERSION: str = ".".join(
    str(part) for part in (pyver.major, pyver.minor, pyver.micro)
)

# File names looked up inside a model's filesystem.
MODEL_YAML_FILENAME = "model.yaml"
CUSTOM_OBJECTS_FILENAME = "custom_objects.pkl"

# attr.evolve / attr.asdict only accept attrs classes, so the class must be
# declared as one.
# NOTE(review): `attr.define` (mutable) is assumed here — confirm against upstream.
@attr.define
class ModelOptions:
    """Base class for the default options used when loading a model.

    The base class declares no fields; subclasses add them as attrs fields.
    """

    def with_options(self, **kwargs: t.Any) -> ModelOptions:
        """Return a copy of these options with ``kwargs`` applied.

        Args:
            **kwargs: field names and their replacement values, forwarded to
                ``attr.evolve``.

        Returns:
            A new options instance; ``self`` is left untouched.
        """
        return attr.evolve(self, **kwargs)

    def to_dict(self: ModelOptions) -> dict[str, t.Any]:
        """Return the option fields as a plain dictionary (``attr.asdict``)."""
        return attr.asdict(self)

# Without an attrs decorator `attr.field(...)` is never processed and
# `partial_kwargs` would be a raw class attribute shared by all instances.
# NOTE(review): `attr.define` is assumed — confirm against upstream.
@attr.define
class PartialKwargsModelOptions(ModelOptions):
    """Model options that additionally carry a ``partial_kwargs`` mapping.

    ``partial_kwargs`` defaults to a fresh empty dict for every instance.
    """

    partial_kwargs: t.Dict[str, t.Any] = attr.field(factory=dict)

@attr.define(repr=False, eq=False, init=False)
class Model(StoreItem):
    """A model identified by a tag, backed by a filesystem and a ``ModelInfo``.

    Instances are created through :meth:`create`, :meth:`from_fs` or
    :meth:`with_options`; calling the constructor without ``_internal=True``
    raises ``BentoMLException``. Equality and hashing use the tag only.
    """

    _tag: Tag
    __fs: FS

    _info: ModelInfo
    _custom_objects: dict[str, t.Any] | None = None

    # Filled in lazily by to_runnable().
    _runnable: t.Type[Runnable] | None = attr.field(init=False, default=None)

    # Filled in lazily by load_model().
    _model: t.Any = None

    def __init__(
        self,
        tag: Tag,
        model_fs: FS,
        info: ModelInfo,
        custom_objects: dict[str, t.Any] | None = None,
        *,
        _internal: bool = False,
    ):
        """Initialise the attrs fields; reserved for internal callers.

        Args:
            tag: tag identifying this model.
            model_fs: filesystem holding the model files.
            info: metadata describing the model.
            custom_objects: optional extra objects kept alongside the model.
            _internal: must be ``True``; guards against direct instantiation.

        Raises:
            BentoMLException: if ``_internal`` is not set.
        """
        if not _internal:
            raise BentoMLException(
                "Model cannot be instantiated directly; use bentoml.<framework>.save or bentoml.models.get instead"
            )

        self.__attrs_init__(tag, model_fs, info, custom_objects)  # type: ignore (no types for attrs init)

    @staticmethod
    def _export_ext() -> str:
        """Return the file extension used for exported models."""
        return "bentomodel"

    @property
    def tag(self) -> Tag:
        """The tag identifying this model."""
        return self._tag

    @property
    def _fs(self) -> FS:
        """The filesystem currently backing this model."""
        return self.__fs

    @property
    def info(self) -> ModelInfo:
        """The metadata describing this model."""
        return self._info

    @property
    def custom_objects(self) -> t.Dict[str, t.Any]:
        """Custom objects stored with the model, loaded on first access.

        If none were supplied at construction time they are unpickled from
        ``CUSTOM_OBJECTS_FILENAME`` when that file exists, otherwise an empty
        dict is used.

        Raises:
            ValueError: if the pickled payload is not a ``dict``.
        """
        if self._custom_objects is None:
            if self._fs.isfile(CUSTOM_OBJECTS_FILENAME):
                # SECURITY: unpickling runs arbitrary code; only open model
                # directories that come from a trusted source.
                with self._fs.open(CUSTOM_OBJECTS_FILENAME, "rb") as cofile:
                    loaded = cloudpickle.load(cofile)
                # Validate before caching so an invalid payload is never
                # handed out by a later access.
                if not isinstance(loaded, dict):
                    raise ValueError("Invalid custom objects found.")
                self._custom_objects = loaded
            else:
                self._custom_objects = {}

        return self._custom_objects

    def __eq__(self, other: object) -> bool:
        return isinstance(other, Model) and self._tag == other._tag

    def __hash__(self) -> int:
        return hash(self._tag)

    @classmethod
    def create(
        cls,
        name: Tag | str,
        *,
        module: str,
        api_version: str,
        signatures: ModelSignaturesType,
        labels: dict[str, str] | None = None,
        options: ModelOptions | None = None,
        custom_objects: dict[str, t.Any] | None = None,
        metadata: dict[str, t.Any] | None = None,
        context: ModelContext,
    ) -> Model:
        """Create a new Model instance in temporary filesystem used for serializing
        model artifacts and save to model store

        Args:
            name: model name in target model store, model version will be automatically
                generated when the tag carries no version
            module: import path of module used for saving/loading this model
            api_version: stored in the model's ``ModelInfo``
            signatures: model signatures, stored in the model's ``ModelInfo``
            labels:  user-defined labels for managing models, e.g. team=nlp, stage=dev
            options: default options for loading this model, defined by runner
                implementation, e.g. xgboost booster_params
            custom_objects: user-defined additional python objects to be saved
                alongside the model, e.g. a tokenizer instance, preprocessor function,
                model configuration json
            metadata: user-defined metadata for storing model training context
                information or model evaluation metrics, e.g. dataset version,
                training parameters, model scores
            context: Environment context managed by BentoML for loading model,
                e.g. {"framework:" "tensorflow", "framework_version": _tf_version}

        Returns:
            object: Model instance created in temporary filesystem
        """
        tag = Tag.from_taglike(name)
        if tag.version is None:
            tag = tag.make_new_version()
        labels = {} if labels is None else labels
        metadata = {} if metadata is None else metadata
        options = ModelOptions() if options is None else options

        model_fs = fs.open_fs(f"temp://bentoml_model_{tag.name}")

        return cls(
            tag,
            model_fs,
            ModelInfo(
                tag=tag,
                module=module,
                api_version=api_version,
                signatures=signatures,
                labels=labels,
                options=options,
                metadata=metadata,
                context=context,
            ),
            custom_objects=custom_objects,
            _internal=True,
        )

    @inject
    def save(
        self,
        model_store: ModelStore = Provide[BentoMLContainer.model_store],
    ) -> Model:
        """Validate this model and mirror its files into ``model_store``.

        After a successful save the model is backed by the filesystem inside
        the store.

        Returns:
            ``self``.

        Raises:
            BentoMLException: if validation fails.
        """
        try:
            self.validate()
        except BentoMLException as e:
            raise BentoMLException(f"Failed to save {self!s}: {e}") from None

        with model_store.register(self.tag) as model_path:
            out_fs = fs.open_fs(model_path, create=True, writeable=True)
            fs.mirror.mirror(self._fs, out_fs, copy_if_newer=False)
            self.__fs = out_fs

        return self

    @classmethod
    def from_fs(cls: t.Type[Model], item_fs: FS) -> Model:
        """Build a model from a filesystem containing ``MODEL_YAML_FILENAME``.

        Raises:
            BentoMLException: if the yaml file is missing or validation fails.
        """
        try:
            with item_fs.open(MODEL_YAML_FILENAME, "r") as model_yaml:
                info = ModelInfo.from_yaml_file(model_yaml)
        except fs.errors.ResourceNotFound:
            raise BentoMLException(
                f"Failed to load bento model because it does not contain a '{MODEL_YAML_FILENAME}'"
            )

        res = Model(tag=info.tag, model_fs=item_fs, info=info, _internal=True)
        try:
            res.validate()
        except BentoMLException as e:
            raise BentoMLException(f"Failed to load {res!s}: {e}") from None

        return res

    @classmethod
    def enter_cloudpickle_context(
        cls,
        external_modules: list[ModuleType],
        imported_modules: list[ModuleType],
    ) -> list[ModuleType]:
        """
        Enter a context for cloudpickle to pickle custom objects defined in external modules.

        Args:
            external_modules: list of external modules to pickle
            imported_modules: list to added modules, needs to be surely unregistered after pickling

        Returns:
            list[ModuleType]: list of module names that were imported in the context

        Raises:
            ValueError: if any of the external modules is not importable
        """
        if not external_modules:
            return []

        registed_before: set[str] = cloudpickle.list_registry_pickle_by_value()
        for mod in external_modules:
            # Modules already registered elsewhere must stay registered on
            # exit, so they are not recorded in imported_modules.
            if mod.__name__ in registed_before:
                continue
            # NOTE(review): the two statements below were missing from this
            # copy of the file and are restored from the docstring's contract.
            cloudpickle.register_pickle_by_value(mod)
            imported_modules.append(mod)

        return imported_modules

    @classmethod
    def exit_cloudpickle_context(cls, imported_modules: list[ModuleType]) -> None:
        """
        Exit the context for cloudpickle, unregister imported external modules.

        Needs to be called after self.flush
        """
        if not imported_modules:
            return

        for mod in imported_modules:
            cloudpickle.unregister_pickle_by_value(mod)

    def flush(self):
        """Write the model info and the custom objects to the model filesystem."""
        self._write_info()
        self._write_custom_objects()

    def _write_info(self):
        """Dump ``self.info`` as YAML into ``MODEL_YAML_FILENAME``."""
        with self._fs.open(MODEL_YAML_FILENAME, "w", encoding="utf-8") as model_yaml:
            self.info.dump(t.cast(io.StringIO, model_yaml))

    def _write_custom_objects(self):
        """Pickle the custom objects into ``CUSTOM_OBJECTS_FILENAME``."""
        # pickle custom_objects if it is not None and not empty
        if self.custom_objects:
            with self._fs.open(CUSTOM_OBJECTS_FILENAME, "wb") as cofile:
                cloudpickle.dump(self.custom_objects, cofile)  # type: ignore (incomplete cloudpickle types)

    # NOTE(review): exposed as a property like the other accessors — confirm
    # against StoreItem.
    @property
    def creation_time(self) -> datetime:
        """Creation time recorded in the model info."""
        return self.info.creation_time

    def validate(self):
        """Check that the model filesystem contains ``MODEL_YAML_FILENAME``.

        Raises:
            BentoMLException: if the file is missing.
        """
        if not self._fs.isfile(MODEL_YAML_FILENAME):
            raise BentoMLException(
                f"{self!s} does not contain a {MODEL_YAML_FILENAME}."
            )

    def __str__(self):
        return f'Model(tag="{self.tag}")'

    def __repr__(self):
        # `path` is not defined in this class body; presumably provided by
        # StoreItem — confirm.
        return f'Model(tag="{self.tag}", path="{self.path}")'

    def to_runner(
        self,
        name: str = "",
        max_batch_size: int | None = None,
        max_latency_ms: int | None = None,
        method_configs: dict[str, dict[str, int]] | None = None,
        embedded: bool = False,
        scheduling_strategy: type[Strategy] | None = None,
    ) -> Runner:
        """Create a ``Runner`` for this model.

        TODO(chaoyu): add docstring

        Args:
            name: runner name; an empty string falls back to the tag's name.
            max_batch_size: forwarded to ``Runner``.
            max_latency_ms: forwarded to ``Runner``.
            method_configs: forwarded to ``Runner``.
            embedded: forwarded to ``Runner``; forced to ``False`` (with a
                warning) when the ``YATAI_T_VERSION`` environment variable is set.
            scheduling_strategy: defaults to ``DefaultStrategy`` when ``None``.

        Returns:
            The newly created ``Runner``.
        """
        from ..runner import Runner
        from ..runner.strategy import DefaultStrategy

        if scheduling_strategy is None:
            scheduling_strategy = DefaultStrategy

        # TODO: @larme @yetone run this branch only yatai version is incompatible with embedded runner
        yatai_version = os.environ.get("YATAI_T_VERSION")
        if embedded and yatai_version:
            logger.warning(
                f"Yatai of version {yatai_version} is incompatible with embedded runner, set `embedded=False` for runner {name}"
            )
            embedded = False

        # NOTE(review): the argument list was truncated in this copy of the
        # file; it is restored from this method's own parameters — confirm
        # against Runner.__init__.
        return Runner(
            self.to_runnable(),
            name=name if name != "" else self.tag.name,
            models=[self],
            max_batch_size=max_batch_size,
            max_latency_ms=max_latency_ms,
            method_configs=method_configs,
            embedded=embedded,
            scheduling_strategy=scheduling_strategy,
        )

    def to_runnable(self) -> t.Type[Runnable]:
        """Return the runnable class for this model, creating it on first use."""
        if self._runnable is None:
            self._runnable = self.info.imported_module.get_runnable(self)
        return self._runnable

    def load_model(self, *args: t.Any, **kwargs: t.Any) -> t.Any:
        """
        Load the model into memory from the model store directory.
        This is a shortcut to the ``load_model`` function defined in the framework module
        used for saving the target model.

        For example, if the ``BentoModel`` is saved with
        ``bentoml.tensorflow.save_model``, this method will pass it to the
        ``bentoml.tensorflow.load_model`` method, along with any additional arguments.

        The loaded object is cached; arguments passed on later calls are ignored.
        """
        if self._model is None:
            self._model = self.info.imported_module.load_model(self, *args, **kwargs)
        return self._model

    def with_options(self, **kwargs: t.Any) -> Model:
        """Return a new ``Model`` sharing this tag and filesystem, with updated options."""
        res = Model(
            self._tag,
            self._fs,
            self.info.with_options(**kwargs),
            self._custom_objects,
            _internal=True,
        )
        return res

class ModelStore(Store[Model]):
    """A ``Store`` parameterised for ``Model`` items."""

    def __init__(self, base_path: "t.Union[PathType, FS]"):
        # Model is handed to Store.__init__ as its second positional argument.
        super().__init__(base_path, Model)

# The fields below use attr.field and the class is (un)structured through
# cattrs, so it has to be an attrs class.
# NOTE(review): `attr.frozen` is chosen to match ModelSignature — confirm against upstream.
@attr.frozen
class ModelContext:
    """Environment information recorded for a model.

    Fields:
        framework_name: name of the framework.
        framework_versions: mapping of names to version strings.
        bentoml_version: defaults to ``BENTOML_VERSION``.
        python_version: defaults to ``PYTHON_VERSION``.
    """

    framework_name: str
    framework_versions: t.Dict[str, str]

    # using factory explicitly instead of default because omit_if_default is enabled in ModelInfo
    bentoml_version: str = attr.field(factory=lambda: BENTOML_VERSION)
    python_version: str = attr.field(factory=lambda: PYTHON_VERSION)

    @staticmethod
    def from_dict(data: dict[str, str | dict[str, str]] | ModelContext) -> ModelContext:
        """Return ``data`` unchanged if it already is a ``ModelContext``,
        otherwise structure it from a dictionary via ``bentoml_cattr``."""
        if isinstance(data, ModelContext):
            return data
        return bentoml_cattr.structure(data, ModelContext)

    def to_dict(self: ModelContext) -> dict[str, str | dict[str, str]]:
        """Return this context as a plain dictionary via ``bentoml_cattr``."""
        return bentoml_cattr.unstructure(self)

@attr.frozen
class ModelSignature:
    """
    A model signature represents a method on a model object that can be called.

    This information is used when creating BentoML runners for this model.

    Note that anywhere a ``ModelSignature`` is used, a ``dict`` with keys corresponding to the
    fields can be used instead. For example, instead of ``{"predict":
    ModelSignature(batchable=True)}``, one can pass ``{"predict": {"batchable": True}}``.

    Fields:
        batchable:
            Whether multiple API calls to this predict method should be batched by the BentoML
            runner.
        batch_dim:
            The dimension(s) that contain multiple data when passing to this prediction method.

            For example, if you have two inputs you want to run prediction on, ``[1, 2]`` and ``[3,
            4]``, if the array you would pass to the predict method would be ``[[1, 2], [3, 4]]``,
            then the batch dimension would be ``0``. If the array you would pass to the predict
            method would be ``[[1, 3], [2, 4]]``, then the batch dimension would be ``1``.

            If there are multiple arguments to the predict method and there is only one batch
            dimension supplied, all arguments will use that batch dimension.

            Example:

            .. code-block:: python

                # Save two models with `predict` method that supports taking input batches on the
                # dimension 0 and the other on dimension 1:
                bentoml.pytorch.save_model("demo0", model_0, signatures={"predict": {"batchable": True, "batch_dim": 0}})
                bentoml.pytorch.save_model("demo1", model_1, signatures={"predict": {"batchable": True, "batch_dim": 1}})

                # if the following calls are batched, the input to the actual predict method on the
                # model.predict method would be [[1, 2], [3, 4], [5, 6]]
                runner0 = bentoml.pytorch.get("demo0:latest").to_runner()
                runner0.init_local()
                runner0.predict.run(np.array([[1, 2], [3, 4]]))
                runner0.predict.run(np.array([[5, 6]]))

                # if the following calls are batched, the input to the actual predict method on the
                # model.predict would be [[1, 2, 5], [3, 4, 6]]
                runner1 = bentoml.pytorch.get("demo1:latest").to_runner()
                runner1.init_local()
                runner1.predict.run(np.array([[1, 2], [3, 4]]))
                runner1.predict.run(np.array([[5], [6]]))

            Expert API:

            The batch dimension can also be a tuple of (input batch dimension, output batch
            dimension). For example, if the predict method should have its input batched along the
            first axis and its output batched along the zeroth axis, ``batch_dim`` can be set to
            ``(1, 0)``.

        input_spec: Reserved for future use.

        output_spec: Reserved for future use.
    """

    batchable: bool = False
    batch_dim: t.Tuple[int, int] = (0, 0)
    # TODO: define input/output spec struct
    input_spec: t.Any = None
    output_spec: t.Any = None

    @classmethod
    def from_dict(cls, data: ModelSignatureDict) -> ModelSignature:
        """Structure a signature from its dictionary form.

        An integer ``batch_dim`` is expanded to an ``(input, output)`` pair
        using the same value for both; ``data`` itself is not modified.
        """
        if "batch_dim" in data and isinstance(data["batch_dim"], int):
            dim = data["batch_dim"]
            formatted_data = dict(data, batch_dim=(dim, dim))
        else:
            formatted_data = data
        return bentoml_cattr.structure(formatted_data, cls)

    @staticmethod
    def convert_signatures_dict(
        data: dict[str, ModelSignatureDict | ModelSignature],
    ) -> dict[str, ModelSignature]:
        """Convert every ``dict`` value to a ``ModelSignature``; other values pass through."""
        return {
            k: ModelSignature.from_dict(v) if isinstance(v, dict) else v
            for k, v in data.items()
        }
if TYPE_CHECKING:
    ModelSignaturesType = dict[str, ModelSignature] | dict[str, ModelSignatureDict]


def model_signature_unstructure_hook(
    model_signature: ModelSignature,
) -> dict[str, t.Any]:
    """Unstructure a ``ModelSignature`` into a dictionary.

    ``batch_dim`` is emitted only for batchable signatures, and the spec
    fields only when they are not ``None``.
    """
    encoded: dict[str, t.Any] = {
        "batchable": model_signature.batchable,
    }
    # ignore batch_dim if batchable is False
    if model_signature.batchable:
        encoded["batch_dim"] = model_signature.batch_dim
    if model_signature.input_spec is not None:
        encoded["input_spec"] = model_signature.input_spec
    if model_signature.output_spec is not None:
        encoded["output_spec"] = model_signature.output_spec
    return encoded


bentoml_cattr.register_unstructure_hook(
    ModelSignature, model_signature_unstructure_hook
)


@attr.define(repr=False, eq=False, frozen=True, init=False)
class ModelInfo:
    """Metadata describing a model, (de)serialised to ``MODEL_YAML_FILENAME``.

    The class is frozen; the two ``_cached_*`` fields are filled in lazily
    with ``object.__setattr__``.
    """

    # for backward compatibility in case new fields are added to BentoInfo.
    __forbid_extra_keys__ = False
    # omit field in yaml file if it is not provided by the user.
    __omit_if_default__ = True

    tag: Tag
    name: str
    version: str
    module: str
    labels: t.Dict[str, str] = attr.field(
        validator=label_validator, converter=normalize_labels_value
    )
    _options: t.Dict[str, t.Any]
    metadata: MetadataDict = attr.field(validator=metadata_validator, converter=dict)
    context: ModelContext = attr.field()
    signatures: t.Dict[str, ModelSignature] = attr.field(
        converter=ModelSignature.convert_signatures_dict
    )
    api_version: str
    creation_time: datetime

    _cached_module: t.Optional[ModuleType] = None
    _cached_options: t.Optional[ModelOptions] = None

    def __init__(
        self,
        tag: Tag,
        module: str,
        labels: dict[str, str],
        options: dict[str, t.Any] | ModelOptions,
        metadata: MetadataDict,
        context: ModelContext,
        signatures: ModelSignaturesType,
        api_version: str,
        creation_time: datetime | None = None,
    ):
        """Initialise the info; ``name`` and ``version`` are derived from ``tag``.

        Args:
            tag: tag of the model.
            module: import path of the module used to save/load the model.
            labels: user-defined labels.
            options: loading options, either as a dict or a ``ModelOptions``
                instance (which is cached and stored in dict form).
            metadata: user-defined metadata.
            context: environment context of the model.
            signatures: model signatures, as objects or dicts.
            api_version: API version string.
            creation_time: defaults to the current UTC time.
        """
        if isinstance(options, ModelOptions):
            object.__setattr__(self, "_cached_options", options)
            options = options.to_dict()

        if creation_time is None:
            creation_time = datetime.now(timezone.utc)

        self.__attrs_init__(  # type: ignore
            tag=tag,
            name=tag.name,
            version=tag.version,
            module=module,
            labels=labels,
            options=options,
            metadata=metadata,
            context=context,
            signatures=signatures,
            api_version=api_version,
            creation_time=creation_time,
        )
        self.validate()

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, ModelInfo):
            return False

        return (
            self.tag == other.tag
            and self.module == other.module
            and self.signatures == other.signatures
            and self.labels == other.labels
            and self.options == other.options
            and self.metadata == other.metadata
            and self.context == other.context
            and self.api_version == other.api_version
            and self.creation_time == other.creation_time
        )

    def with_options(self, **kwargs: t.Any) -> ModelInfo:
        """Return a copy of this info whose options are updated with ``kwargs``."""
        return ModelInfo(
            tag=self.tag,
            module=self.module,
            signatures=self.signatures,
            labels=self.labels,
            options=self.options.with_options(**kwargs),
            metadata=self.metadata,
            context=self.context,
            api_version=self.api_version,
            creation_time=self.creation_time,
        )

    # cached_property doesn't support __slots__ classes
    @property
    def imported_module(self) -> ModuleType:
        """The module named by ``self.module``, imported on first access.

        Raises:
            BentoMLException: if ``module`` is empty or cannot be imported.
        """
        if self._cached_module is None:
            if not self.module:
                raise BentoMLException(
                    f"Module is not defined in {MODEL_YAML_FILENAME}. If the module argument is not defined when creating a model using `bentoml.models.create`, methods that use `ModelInfo.imported_module` are not supported."
                ) from None
            try:
                object.__setattr__(
                    self, "_cached_module", importlib.import_module(self.module)
                )
            except (ValueError, ModuleNotFoundError) as e:
                raise BentoMLException(
                    f"Module '{self.module}' defined in {MODEL_YAML_FILENAME} is not found."
                ) from e
        assert self._cached_module is not None
        return self._cached_module

    @property
    def options(self) -> ModelOptions:
        """The options as a ``ModelOptions`` object, built on first access.

        Uses the ``ModelOptions`` class of the imported module when it defines
        one, otherwise an empty base ``ModelOptions``.
        """
        if self._cached_options is None:
            if self.module and hasattr(self.imported_module, "ModelOptions"):
                object.__setattr__(
                    self,
                    "_cached_options",
                    self.imported_module.ModelOptions(**self._options),
                )
            else:
                object.__setattr__(self, "_cached_options", ModelOptions())
        assert self._cached_options is not None
        return self._cached_options

    def to_dict(self) -> t.Dict[str, t.Any]:
        """Return this info as a plain dictionary via ``bentoml_cattr``."""
        return bentoml_cattr.unstructure(self)

    # yaml.safe_dump writes to the stream and returns None when a stream is
    # given, and returns the produced text when it is not.
    @overload
    def dump(self, stream: io.StringIO) -> None:
        ...

    @overload
    def dump(self, stream: None = None) -> str:
        ...

    def dump(self, stream: io.StringIO | None = None) -> str | None:
        """Serialise this info as YAML, keeping the field order."""
        return yaml.safe_dump(self.to_dict(), stream=stream, sort_keys=False)  # type: ignore (bad yaml types)

    @classmethod
    def from_yaml_file(cls, stream: t.IO[t.Any]) -> ModelInfo:
        """Parse a ``ModelInfo`` from a YAML stream.

        Raises:
            yaml.YAMLError: if the stream is not valid YAML (logged, then re-raised).
            BentoMLException: if the document is not a mapping or cannot be
                structured.
        """
        try:
            yaml_content = yaml.safe_load(stream)
        except yaml.YAMLError as exc:  # pragma: no cover - simple error handling
            logger.error(exc)
            raise

        if not isinstance(yaml_content, dict):
            raise BentoMLException(f"malformed {MODEL_YAML_FILENAME}")

        # The file stores name and version separately; the class takes a tag.
        yaml_content["tag"] = str(
            Tag(t.cast(str, yaml_content["name"]), t.cast(str, yaml_content["version"]))
        )
        del yaml_content["name"]
        del yaml_content["version"]

        # For backwards compatibility for bentos created prior to version 1.0.0rc1
        if "bentoml_version" in yaml_content:
            del yaml_content["bentoml_version"]
        if "signatures" not in yaml_content:
            yaml_content["signatures"] = {}
        if "context" in yaml_content and "pip_dependencies" in yaml_content["context"]:
            del yaml_content["context"]["pip_dependencies"]
            yaml_content["context"]["framework_versions"] = {}

        try:
            model_info = bentoml_cattr.structure(yaml_content, cls)
        except TypeError as e:  # pragma: no cover - simple error handling
            raise BentoMLException(
                f"unexpected field in {MODEL_YAML_FILENAME}: {e}"
            ) from e
        return model_info

    def validate(self):
        # Validate model.yml file schema, content, bentoml version, etc
        # add tests when implemented
        ...


bentoml_cattr.register_structure_hook_func(
    lambda cls: issubclass(cls, ModelInfo),
    make_dict_structure_fn(
        ModelInfo,
        bentoml_cattr,
        name=override(omit=True),
        version=override(omit=True),
        _options=override(rename="options"),
    ),
)
bentoml_cattr.register_unstructure_hook_func(
    lambda cls: issubclass(cls, ModelInfo),
    # Ignore tag, tag is saved via the name and version field
    make_dict_unstructure_fn(
        ModelInfo,
        bentoml_cattr,
        tag=override(omit=True),
        _options=override(rename="options"),
        _cached_module=override(omit=True),
        _cached_options=override(omit=True),
    ),
)


def copy_model(
    model_tag: t.Union[Tag, str],
    *,
    src_model_store: ModelStore,
    target_model_store: ModelStore,
):
    """copy a model from src model store to target modelstore, and do nothing if the
    model tag already exist in target model store
    """
    try:
        target_model_store.get(model_tag)  # if model tag already found in target
        return
    except NotFound:
        pass

    model = src_model_store.get(model_tag)
    # NOTE(review): the copy step was missing from this copy of the file;
    # restored from the docstring and Model.save's signature — confirm.
    model.save(target_model_store)


def _ModelInfo_dumper(dumper: yaml.Dumper, info: ModelInfo) -> yaml.Node:
    """Represent a ``ModelInfo`` as a YAML mapping of its dictionary form."""
    return dumper.represent_dict(info.to_dict())


yaml.add_representer(ModelInfo, _ModelInfo_dumper)  # type: ignore (incomplete yaml types)