Flax
About this page
This is an API reference for Flax in BentoML. Please refer to /frameworks/flax for more information about how to use Flax in BentoML.
- bentoml.flax.save_model(name: Tag | str, module: nn.Module, state: dict[str, t.Any] | FrozenDict[str, t.Any] | struct.PyTreeNode, *, signatures: ModelSignaturesType | None = None, labels: dict[str, str] | None = None, custom_objects: dict[str, t.Any] | None = None, external_modules: t.List[ModuleType] | None = None, metadata: dict[str, t.Any] | None = None) -> bentoml.Model
Save a flax.linen.Module model instance to the BentoML model store.
- Parameters:
name – The name to give to the model in the BentoML store. This must be a valid Tag name.
module – The flax.linen.Module to be saved.
state – The state of the module to be saved, for example the variables returned by module.init or a training state. Per the signature, this is a dict, a FrozenDict, or a struct.PyTreeNode (flax.training.train_state.TrainState is a PyTreeNode).
signatures – Signatures of predict methods to be used. If not provided, the signatures default to predict. See ModelSignature for more details.
labels – A default set of management labels to be associated with the model. An example is {"training-set": "data-1"}.
custom_objects – Custom objects to be saved with the model. An example is {"my-normalizer": normalizer}. Custom objects are currently serialized with cloudpickle, but this implementation is subject to change.
external_modules – User-defined additional Python modules to be saved alongside the model or custom objects, e.g. a tokenizer module, a preprocessor module, or a model configuration module.
metadata – Metadata to be associated with the model. An example is {"bias": 4}. Metadata is intended for display in a model management UI and therefore must be a default Python type, such as str or int.
- Returns:
The saved bentoml.Model. Its tag can be used to access the saved model from the BentoML model store.
- Return type:
Model
Example:
import logging

import bentoml
import jax

logger = logging.getLogger(__name__)

# CNN, create_train_state, train_epoch, apply_model, config, train_ds and test_ds
# are assumed to be defined as in the Flax MNIST example.
rng = jax.random.PRNGKey(0)
rng, init_rng = jax.random.split(rng)
state = create_train_state(init_rng, config)

for epoch in range(1, config.num_epochs + 1):
    rng, input_rng = jax.random.split(rng)
    state, train_loss, train_accuracy = train_epoch(
        state, train_ds, config.batch_size, input_rng
    )
    _, test_loss, test_accuracy = apply_model(
        state, test_ds["image"], test_ds["label"]
    )
    logger.info(
        "epoch:% 3d, train_loss: %.4f, train_accuracy: %.2f, test_loss: %.4f, test_accuracy: %.2f",
        epoch, train_loss, train_accuracy * 100, test_loss, test_accuracy * 100,
    )

# Save the trained model with BentoML
bento_model = bentoml.flax.save_model("mnist", CNN(), state)
- bentoml.flax.load_model(bento_model: str | Tag | bentoml.Model, init: bool = True, device: str | XlaBackend = 'cpu') -> tuple[nn.Module, dict[str, t.Any]]
Load the flax.linen.Module model instance with the given tag from the local BentoML model store.
- Parameters:
bento_model – Either the tag of the model to get from the store, or a bentoml.Model instance to load the model from.
init – Whether to initialize the state dict of the given flax.linen.Module. By default (init=True), the weights and values are converted to jnp.ndarray. If init is set to False, the state dict is instead only placed on the given accelerator device.
device – The device to put the state dict on. Defaults to cpu. This is only used when init is set to False.
- Returns:
A tuple of the flax.linen.Module and its state_dict from the model store.
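To make the two modes of init and device concrete, here is a plain-JAX sketch of the two placement strategies they describe. It is not BentoML's implementation: the restored dict, its "Dense_0" layer name and its NumPy leaves are hypothetical stand-ins for a deserialized state dict.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Hypothetical restored state dict whose leaves are plain NumPy arrays.
restored = {
    "params": {
        "Dense_0": {
            "kernel": np.ones((4, 2), np.float32),
            "bias": np.zeros((2,), np.float32),
        }
    }
}

# init=True behaviour as documented: convert every leaf to a jnp.ndarray.
initialized = jax.tree_util.tree_map(jnp.asarray, restored)

# init=False behaviour as documented: only place the leaves on the requested device.
device = jax.devices("cpu")[0]
placed = jax.device_put(restored, device)
```

Both calls keep the nested structure of the dict; only the array type and placement of the leaves change.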
Example:
import bentoml
import jax
import jax.numpy as jnp

net, state_dict = bentoml.flax.load_model("mnist:latest")
# Compile a prediction function that applies the loaded parameters to a batch x
predict_fn = jax.jit(lambda x: net.apply({"params": state_dict["params"]}, x))
results = predict_fn(jnp.ones((1, 28, 28, 1)))
- bentoml.flax.get(tag_like: str | Tag) -> bentoml.Model
Get the BentoML model with the given tag.
- Parameters:
tag_like β The tag of the model to retrieve from the model store.
- Returns:
A BentoML Model with the matching tag.
- Return type:
Model
Example:
import bentoml

model = bentoml.flax.get("mnist:latest")