Source code for bentoml._internal.frameworks.torchscript

from __future__ import annotations

import logging
import typing as t
from types import ModuleType
from typing import TYPE_CHECKING

import bentoml
from bentoml import Tag

from ...exceptions import NotFound
from ..models.model import Model
from ..models.model import ModelContext
from ..models.model import PartialKwargsModelOptions as ModelOptions
from ..utils.pkg import get_pkg_version
from .common.pytorch import torch

if TYPE_CHECKING:
    from ..models.model import ModelSignaturesType


# Module-level logger named after this module's import path.
logger = logging.getLogger(__name__)
# Module identifier stored on saved models (default ``_module_name`` of
# save_model). get() and load_model() reject models whose recorded module is
# neither this value nor this module's ``__name__``.
MODULE_NAME = "bentoml.torchscript"
# File name resolved through ``Model.path_of`` when the TorchScript module is
# written by save_model() and read back by load_model().
MODEL_FILENAME = "saved_model.pt"
# Passed as ``api_version`` when save_model() creates the model entry.
API_VERSION = "v1"


def get(tag_like: str | Tag) -> Model:
    """
    Fetch a model from the BentoML model store and verify it is a TorchScript model.

    Args:
        tag_like: Tag (or its string form) identifying the model to look up.

    Returns:
        The :obj:`Model` returned by ``bentoml.models.get``.

    Raises:
        NotFound: If the model's recorded module is neither ``MODULE_NAME``
            nor this module's ``__name__``.
    """
    stored = bentoml.models.get(tag_like)
    saved_module = stored.info.module
    # Accept both the public module name and this internal module's own name.
    if saved_module in (MODULE_NAME, __name__):
        return stored
    raise NotFound(
        f"Model {stored.tag} was saved with module {saved_module}, not loading with {MODULE_NAME}."
    )
def load_model(
    bentoml_model: str | Tag | Model,
    device_id: str | None = "cpu",
    *,
    _extra_files: dict[str, t.Any] | None = None,
) -> torch.ScriptModule | tuple[torch.ScriptModule, dict[str, t.Any]]:
    """
    Load a TorchScript model from the BentoML local model store.

    Args:
        bentoml_model: Either a tag (``str`` or :obj:`Tag`) of a saved model,
            which is resolved through :func:`get`, or an already-resolved
            :obj:`Model`.
        device_id: Passed to ``torch.jit.load`` as ``map_location``. Refer to
            https://pytorch.org/docs/stable/tensor_attributes.html#torch.torch.device
        _extra_files: Forwarded unchanged to ``torch.jit.load``. See
            https://pytorch.org/docs/stable/generated/torch.jit.load.html.

    Returns:
        :obj:`torch.ScriptModule`: the object returned by ``torch.jit.load``.

    Raises:
        NotFound: If the model's recorded module is neither ``MODULE_NAME``
            nor this module's ``__name__``.

    Examples:

    .. code-block:: python

        import bentoml
        lit = bentoml.torchscript.load_model('lit_classifier:latest', device_id="cuda:0")
    """
    resolved = (
        get(bentoml_model)
        if isinstance(bentoml_model, (str, Tag))
        else bentoml_model
    )

    # A Model instance passed in directly has not gone through get(), so the
    # module check is applied here as well.
    saved_module = resolved.info.module
    if saved_module not in (MODULE_NAME, __name__):
        raise NotFound(
            f"Model {resolved.tag} was saved with module {saved_module}, not loading with {MODULE_NAME}."
        )

    return torch.jit.load(
        resolved.path_of(MODEL_FILENAME),
        map_location=device_id,
        _extra_files=_extra_files,
    )
def save_model(
    name: Tag | str,
    model: torch.ScriptModule,
    *,
    signatures: ModelSignaturesType | None = None,
    labels: t.Dict[str, str] | None = None,
    custom_objects: t.Dict[str, t.Any] | None = None,
    external_modules: t.List[ModuleType] | None = None,
    metadata: t.Dict[str, t.Any] | None = None,
    _framework_name: str = "torchscript",
    _module_name: str = MODULE_NAME,
    _extra_files: dict[str, t.Any] | None = None,
) -> bentoml.Model:
    """
    Save a TorchScript model instance to the BentoML model store.

    Args:
        name (:code:`str`):
            Name for given model instance. This should pass Python identifier check.
        model (`torch.ScriptModule`):
            Instance of model to be saved.
        signatures (:code:`dict`, `optional`):
            A dictionary of method names and their corresponding signatures.
            Defaults to ``{"__call__": {"batchable": False}}`` when omitted.
        labels (:code:`Dict[str, str]`, `optional`, default to :code:`None`):
            user-defined labels for managing models, e.g. team=nlp, stage=dev
        custom_objects (:code:`Dict[str, Any]]`, `optional`, default to :code:`None`):
            user-defined additional python objects to be saved alongside the model,
            e.g. a tokenizer instance, preprocessor function, model configuration json
        external_modules (:code:`List[ModuleType]`, `optional`, default to :code:`None`):
            user-defined additional python modules to be saved alongside the model or
            custom objects, e.g. a tokenizer module, preprocessor module, model
            configuration module
        metadata (:code:`Dict[str, Any]`, `optional`, default to :code:`None`):
            Custom metadata for given model. The dictionary passed in is never
            modified by this function.
        _framework_name: Framework name recorded in the model context. When it is
            ``"pytorch_lightning"`` the ``pytorch_lightning`` package version is
            recorded next to the ``torch`` version.
        _module_name: Module name recorded on the saved model.
        _extra_files: Forwarded to ``torch.jit.save``. When given, its keys are
            also stored in the saved metadata under ``"_extra_files"``.

    Returns:
        :obj:`~bentoml.Model`: the model created in the BentoML model store.

    Raises:
        TypeError: If ``model`` is not a ``torch.ScriptModule``.

    Examples:

    .. code-block:: python

        import bentoml
        import torch

        scripted = torch.jit.script(torch.nn.Linear(4, 2))
        bento_model = bentoml.torchscript.save_model("linear", scripted)
    """
    if not isinstance(model, (torch.ScriptModule, torch.jit.ScriptModule)):
        raise TypeError(f"Given model ({model}) is not a torch.ScriptModule.")

    framework_versions = {"torch": get_pkg_version("torch")}
    if _framework_name == "pytorch_lightning":
        framework_versions["pytorch_lightning"] = get_pkg_version("pytorch_lightning")

    context: ModelContext = ModelContext(
        framework_name=_framework_name,
        framework_versions=framework_versions,
    )

    if _extra_files is not None:
        # Build a new dict rather than writing into the caller's ``metadata``,
        # so the argument is not mutated as a side effect of saving.
        metadata = {**(metadata or {}), "_extra_files": list(_extra_files)}

    if signatures is None:
        signatures = {"__call__": {"batchable": False}}
        logger.info(
            'Using the default model signature for torchscript (%s) for model "%s".',
            signatures,
            name,
        )

    with bentoml.models._create(  # type: ignore
        name,
        module=_module_name,
        api_version=API_VERSION,
        labels=labels,
        signatures=signatures,
        custom_objects=custom_objects,
        external_modules=external_modules,
        options=ModelOptions(),
        context=context,
        metadata=metadata,
    ) as bento_model:
        torch.jit.save(
            model, bento_model.path_of(MODEL_FILENAME), _extra_files=_extra_files
        )
        return bento_model
def get_runnable(bento_model: Model):
    """
    Private API: use :obj:`~bentoml.Model.to_runnable` instead.

    Builds a ``PytorchModelRunnable`` subclass bound to ``bento_model`` with
    :func:`load_model` as its loader, and registers one runnable method per
    entry in the model's signatures.
    """
    from .common.pytorch import (
        PytorchModelRunnable,
        make_pytorch_runnable_method,
        partial_class,
    )

    kwargs_by_method: t.Dict[str, t.Any] = bento_model.info.options.partial_kwargs  # type: ignore

    runnable_cls = partial_class(
        PytorchModelRunnable,
        bento_model=bento_model,
        loader=load_model,
    )

    for sig_name, sig in bento_model.info.signatures.items():
        # Methods without configured partial kwargs get ``None`` from .get().
        runnable_cls.add_method(
            make_pytorch_runnable_method(sig_name, kwargs_by_method.get(sig_name)),
            name=sig_name,
            batchable=sig.batchable,
            batch_dim=sig.batch_dim,
            input_spec=sig.input_spec,
            output_spec=sig.output_spec,
        )

    return runnable_cls