Source code for _bentoml_sdk.models.huggingface

from __future__ import annotations

import hashlib
import json
import os
import shutil
import typing as t
from functools import cached_property

import attrs
from fs.base import FS

from bentoml._internal.bento.bento import BentoModelInfo
from bentoml._internal.cloud.schemas.modelschemas import ModelManifestSchema
from bentoml._internal.cloud.schemas.schemasv1 import CreateModelSchema
from bentoml._internal.models.model import ModelContext
from bentoml._internal.tag import Tag
from bentoml._internal.types import PathType

from .base import Model

if t.TYPE_CHECKING:
    from huggingface_hub import HfApi

CONFIG_FILE = "config.json"
DEFAULT_HF_ENDPOINT = "https://huggingface.co"


[docs] @attrs.define(unsafe_hash=True) class HuggingFaceModel(Model[str]): """A model reference to a Hugging Face model. Args: model_id (str): The model tag. E.g. "google-bert/bert-base-uncased". revision (str, optional): The revision to use. Defaults to "main". endpoint (str, optional): The Hugging Face endpoint to use. Defaults to https://huggingface.co. include (List[str], optional): The files to include. Defaults to all files. exclude (List[str], optional): The files to exclude. Defaults to no files. Returns: str: The downloaded model path. """ model_id: str revision: str = "main" endpoint: t.Optional[str] = attrs.field(factory=lambda: os.getenv("HF_ENDPOINT")) include: t.Optional[t.List[str]] = None exclude: t.Optional[t.List[str]] = None @cached_property def _hf_api(self) -> HfApi: from huggingface_hub import HfApi return HfApi(endpoint=self.endpoint) @cached_property def commit_hash(self) -> str: return ( self._hf_api.model_info(self.model_id, revision=self.revision).sha or self.revision )
[docs] def resolve(self, base_path: t.Union[PathType, FS, None] = None) -> str: from huggingface_hub import snapshot_download if isinstance(base_path, FS): base_path = base_path.getsyspath("/") snapshot_path = snapshot_download( self.model_id, revision=self.revision, endpoint=self.endpoint, cache_dir=os.getenv("BENTOML_HF_CACHE_DIR"), allow_patterns=self.include, ignore_patterns=self.exclude, ) if base_path is not None: model_path = os.path.dirname(os.path.dirname(snapshot_path)) os.makedirs(base_path, exist_ok=True) shutil.copytree( model_path, os.path.join(base_path, os.path.basename(model_path)), symlinks=True, dirs_exist_ok=True, ) return snapshot_path
[docs] def to_info(self, alias: str | None = None) -> BentoModelInfo: model_id = self.model_id.lower() metadata = { "model_id": model_id, "revision": self.commit_hash, "endpoint": self.endpoint or DEFAULT_HF_ENDPOINT, "include": self.include, "exclude": self.exclude, } content_hash = hashlib.md5( json.dumps(metadata, sort_keys=True, separators=(",", ":")).encode() ).hexdigest() tag = Tag(model_id.replace("/", "--"), content_hash) return BentoModelInfo( tag, alias=alias, registry="huggingface", metadata=metadata )
[docs] @classmethod def from_info(cls, info: BentoModelInfo) -> HuggingFaceModel: if not info.metadata: name = info.tag.name return cls(model_id=name.replace("--", "/")) model = cls( model_id=info.metadata["model_id"], revision=info.metadata["revision"], endpoint=info.metadata["endpoint"], include=info.metadata.get("include"), exclude=info.metadata.get("exclude"), ) # the commit hash is frozen in the model info, update the cache directly model.__dict__.update(commit_hash=info.metadata["revision"]) return model
def _get_model_size(self, revision: str) -> int: from huggingface_hub.utils import filter_repo_objects info = self._hf_api.model_info( self.model_id, revision=revision, files_metadata=True ) filtered_files = filter_repo_objects( items=info.siblings or [], allow_patterns=self.include, ignore_patterns=self.exclude, key=lambda f: f.rfilename, ) return sum((file.size or 0) for file in filtered_files) def to_create_schema(self) -> CreateModelSchema: context = ModelContext(framework_name="huggingface", framework_versions={}) endpoint = self.endpoint or DEFAULT_HF_ENDPOINT revision = self.commit_hash url = f"{endpoint}/{self.model_id}/tree/{revision}" metadata = { "registry": "huggingface", "model_id": self.model_id, "revision": revision, "endpoint": endpoint, "include": self.include, "exclude": self.exclude, "url": url, } return CreateModelSchema( description="", version=self.to_info().tag.version or revision, manifest=ModelManifestSchema( module="", metadata=metadata, api_version="v1", bentoml_version=context.bentoml_version, size_bytes=self._get_model_size(revision), context=context.to_dict(), options={}, ), ) def __str__(self) -> str: return f"{self.model_id}:{self.revision}"