Flax
About this page
This is an API reference for Flax in BentoML. Please refer to /frameworks/flax for more information about how to use Flax in BentoML.
- bentoml.flax.save_model(name: Tag | str, module: nn.Module, state: dict[str, t.Any] | FrozenDict[str, t.Any] | struct.PyTreeNode, *, signatures: ModelSignaturesType | None = None, labels: dict[str, str] | None = None, custom_objects: dict[str, t.Any] | None = None, external_modules: t.List[ModuleType] | None = None, metadata: dict[str, t.Any] | None = None) → bentoml.Model
Save a flax.linen.Module model instance to the BentoML model store.

- Parameters:
  - name: The name to give to the model in the BentoML store. This must be a valid Tag name.
  - module: The flax.linen.Module to be saved.
  - state: The state of the module to be saved, given as a dict, a FrozenDict, or a struct.PyTreeNode (for example, a training state).
  - signatures: Signatures of predict methods to be used. If not provided, the signatures default to predict. See ModelSignature for more details.
  - labels: A default set of management labels to be associated with the model. An example is {"training-set": "data-1"}.
  - custom_objects: Custom objects to be saved with the model. An example is {"my-normalizer": normalizer}. Custom objects are currently serialized with cloudpickle, but this implementation is subject to change.
  - external_modules: User-defined additional Python modules to be saved alongside the model or custom objects, e.g. a tokenizer module, a preprocessor module, or a model configuration module.
  - metadata: Metadata to be associated with the model. An example is {"bias": 4}. Metadata is intended for display in a model management UI and therefore must be a default Python type, such as str or int.
- Returns:
  The saved bentoml.Model. Its tag can be used to access the model from the BentoML model store.
- Return type:
  Model
Example:

```python
import logging

import bentoml
import jax

logger = logging.getLogger(__name__)

# CNN, create_train_state, train_epoch, apply_model, config, train_ds and
# test_ds are assumed to be defined as in the Flax MNIST example.
rng = jax.random.PRNGKey(0)
rng, init_rng = jax.random.split(rng)
state = create_train_state(init_rng, config)

for epoch in range(1, config.num_epochs + 1):
    rng, input_rng = jax.random.split(rng)
    state, train_loss, train_accuracy = train_epoch(
        state, train_ds, config.batch_size, input_rng
    )
    _, test_loss, test_accuracy = apply_model(
        state, test_ds["image"], test_ds["label"]
    )
    logger.info(
        "epoch:% 3d, train_loss: %.4f, train_accuracy: %.2f, test_loss: %.4f, test_accuracy: %.2f",
        epoch, train_loss, train_accuracy * 100, test_loss, test_accuracy * 100,
    )

# Save the trained module and its state with BentoML
bento_model = bentoml.flax.save_model("mnist", CNN(), state)
```
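Because labels must map strings to strings and metadata must contain only default Python types, it can help to check both dictionaries before calling save_model. The stdlib-only sketch below is a hypothetical, conservative pre-check for illustration. The helper name check_save_kwargs and the exact set of accepted types (scalars plus lists and dicts built from them) are assumptions, not BentoML's own validation rules.

```python
from __future__ import annotations

from typing import Any

# Assumption: a conservative set of "default Python types" for metadata values.
# BentoML's own validation may accept more (or fewer) types than this sketch.
_SCALARS = (str, int, float, bool, type(None))


def _is_plain(value: Any) -> bool:
    """Return True if value is a scalar or a list/tuple/dict built only from scalars."""
    if isinstance(value, _SCALARS):
        return True
    if isinstance(value, (list, tuple)):
        return all(_is_plain(v) for v in value)
    if isinstance(value, dict):
        return all(isinstance(k, str) and _is_plain(v) for k, v in value.items())
    return False


def check_save_kwargs(labels: dict | None = None, metadata: dict | None = None) -> list[str]:
    """Hypothetical pre-check: collect every problem instead of stopping at the first."""
    problems: list[str] = []
    for key, val in (labels or {}).items():
        if not (isinstance(key, str) and isinstance(val, str)):
            problems.append(f"label {key!r} must map str to str")
    for key, val in (metadata or {}).items():
        if not isinstance(key, str):
            problems.append(f"metadata key {key!r} must be a str")
        elif not _is_plain(val):
            problems.append(f"metadata {key!r} has non-default type {type(val).__name__}")
    return problems


# The label and metadata examples from the parameter list pass the check
print(check_save_kwargs(labels={"training-set": "data-1"}, metadata={"bias": 4}))  # []
# An arbitrary object is rejected as metadata
print(check_save_kwargs(metadata={"normalizer": object()}))
```

Objects like the normalizer in the second call belong in custom_objects instead, which are serialized with cloudpickle rather than displayed in a UI.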
- bentoml.flax.load_model(bento_model: str | Tag | bentoml.Model, init: bool = True, device: str | XlaBackend = 'cpu') → tuple[nn.Module, dict[str, t.Any]]
Load the flax.linen.Module model instance with the given tag from the local BentoML model store.

- Parameters:
  - bento_model: Either the tag of the model to get from the store, or a bentoml.Model instance to load the model from.
  - init: Whether to initialize the state dict of the given flax.linen.Module. By default, the weights and values are loaded as jnp.ndarray. If init is set to False, the state dict is only placed on the given accelerator device instead.
  - device: The device to put the state dict on. Defaults to cpu. This is only used when init is set to False.
- Returns:
  A tuple of the flax.linen.Module and its state_dict from the model store.
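Example:

```python
import bentoml
import jax
import jax.numpy as jnp

net, state_dict = bentoml.flax.load_model("mnist:latest")

# JIT-compile a forward pass that closes over the loaded parameters;
# `x` is the input batch supplied at call time.
predict_fn = jax.jit(lambda x: net.apply({"params": state_dict["params"]}, x))
results = predict_fn(jnp.ones((1, 28, 28, 1)))
```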
- bentoml.flax.get(tag_like: str | Tag) → bentoml.Model
Get the BentoML model with the given tag.

- Parameters:
  - tag_like: The tag of the model to retrieve from the model store.
- Returns:
  A BentoML Model with the matching tag.
- Return type:
  Model
Example:

```python
import bentoml

model = bentoml.flax.get("mnist:latest")
```
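Tags such as "mnist:latest" combine a model name and a version separated by a colon. The stdlib-only sketch below splits a tag string the way this page's examples use it. The helper name split_tag and the rule that a missing version is treated as "latest" are illustrative assumptions; the sketch does not reproduce BentoML's own Tag validation or how the store resolves "latest".

```python
from __future__ import annotations


def split_tag(tag_like: str) -> tuple[str, str]:
    """Split a "name:version" string into (name, version).

    Hypothetical helper. Assumption: a tag without a version refers to the
    most recent model, so the version is normalized to "latest".
    """
    name, sep, version = tag_like.partition(":")
    if not name:
        raise ValueError(f"tag {tag_like!r} has an empty name")
    if sep and not version:
        raise ValueError(f"tag {tag_like!r} has an empty version")
    if ":" in version:
        raise ValueError(f"tag {tag_like!r} contains more than one ':'")
    return name, version or "latest"


print(split_tag("mnist:latest"))  # ('mnist', 'latest')
print(split_tag("mnist"))         # ('mnist', 'latest')
```

Using str.partition keeps the split to a single pass and makes the "no colon" case explicit through the empty separator.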