# Source code for autogluon.cloud.model.foundation_model

"""FoundationModel — predict with pretrained foundation models on AWS."""

from __future__ import annotations

import json
import logging
import tarfile
import tempfile
from abc import abstractmethod
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional, Union

import pandas as pd

from autogluon.common.utils.s3_utils import s3_path_to_bucket_prefix

from ..backend.backend_factory import BackendFactory
from ..backend.constant import SAGEMAKER, TABULAR_SAGEMAKER, TIMESERIES_SAGEMAKER
from ..endpoint.prediction_future import JobPredictionFuture
from ..endpoint.timeseries_endpoint import TimeSeriesEndpoint
from ..scripts.script_manager import ScriptManager
from ..utils.aws_utils import resolve_cloud_output_path
from ..version import __version__
from .registry import get_model_config

logger = logging.getLogger(__name__)

# SageMaker extracts model.tar.gz to /opt/ml/model in the container.
_CONTAINER_WEIGHTS_DIR = "/opt/ml/model/weights"

_AG_CLOUD_VERSION_METADATA_KEY = "autogluon-cloud-version"


def _s3_head_or_none(s3_client: Any, bucket: str, key: str) -> Optional[Dict[str, Any]]:
    """Return ``head_object`` response if the key exists, ``None`` for 404. Other errors propagate."""
    from botocore.exceptions import ClientError

    try:
        return s3_client.head_object(Bucket=bucket, Key=key)
    except ClientError as e:
        if e.response.get("Error", {}).get("Code") in ("404", "NoSuchKey", "NotFound"):
            return None
        raise


class FoundationModel:
    """
    Pretrained foundation model inference on AWS.

    Factory: FoundationModel("chronos-bolt-base", ...) returns the appropriate
    task-specific subclass (TimeSeriesFoundationModel, TabularFoundationModel).

    Examples
    --------
    >>> model = FoundationModel("chronos-bolt-base")
    >>> predictions = model.predict(data, prediction_length=12)
    """

    _backend_map: Dict[str, str] = {}
    _predictor_type: str

    def __new__(cls, model_id: str, **kwargs) -> "FoundationModel":
        if cls is not FoundationModel:
            return super().__new__(cls)
        config = get_model_config(model_id)
        task = config.task
        if task == "forecasting":
            return super().__new__(TimeSeriesFoundationModel)
        elif task in ("classification", "regression"):
            return super().__new__(TabularFoundationModel)
        raise ValueError(f"Unsupported task: {task}")

    def __init__(
        self,
        model_id: str,
        *,
        cloud_output_path: Optional[str] = None,
        role: Optional[str] = None,
        hyperparameters: Optional[Dict[str, Any]] = None,
        model_artifact_uri: Optional[str] = None,
        backend: Literal["sagemaker"] = "sagemaker",
    ):
        """
        Parameters
        ----------
        model_id
            ID of the foundation model from the model registry.
        cloud_output_path
            S3 location where intermediate artifacts are stored. Accepts:

            * ``s3://bucket`` — a unique timestamped subfolder ``ag-<timestamp>`` is appended.
            * ``s3://bucket/prefix`` — used verbatim. Re-running with the same prefix will overwrite previously written
              artifacts.
            * ``None`` (default) — use the bucket saved in ``~/.autogluon/cloud.yaml`` (set by
              :func:`autogluon.cloud.bootstrap` / :func:`autogluon.cloud.register`) and append a timestamped subfolder.
              Raises if no bucket is configured.
        role
            ARN of the SageMaker execution role used to run training and inference jobs. If ``None``, falls back to
            ``role_arn`` in ``~/.autogluon/cloud.yaml`` (set by :func:`autogluon.cloud.bootstrap` /
            :func:`autogluon.cloud.register`), and finally to ``sagemaker.get_execution_role()``.
        hyperparameters
            Default hyperparameters applied to inference and (when supported) training.
        model_artifact_uri
            S3 URI of a pre-bundled ``model.tar.gz`` produced by :meth:`cache_model_artifact`. When set, deploys skip
            the runtime HuggingFace download and load weights from the bundled artifact.
        backend
            Cloud backend to use.
        """
        self.model_id = model_id
        self.model_artifact_uri = model_artifact_uri
        self.cloud_output_path = resolve_cloud_output_path(cloud_output_path, backend_name=backend)
        self._config = get_model_config(model_id)
        self._hyperparameter_overrides = hyperparameters or {}
        self._tmpdir = tempfile.TemporaryDirectory(prefix="ag_fm_")

        backend_name = self._backend_map.get(backend)
        if backend_name is None:
            raise ValueError(
                f"Backend '{backend}' is not supported for {self.__class__.__name__}. "
                f"Available: {list(self._backend_map.keys())}"
            )
        self._backend = BackendFactory.get_backend(
            backend=backend_name,
            local_output_path=self._tmpdir.name,
            cloud_output_path=self.cloud_output_path,
            predictor_type=self._predictor_type,
            role=role,
        )

    def _get_hyperparameters(
        self, context: Literal["inference", "training"], overrides: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """Merge registry defaults → constructor overrides → call-site overrides, defaulting ``model_path`` to
        ``model_source_uri`` if not set."""
        if context == "inference":
            registry_defaults = self._config.inference_hyperparameters
        else:
            registry_defaults = self._config.training_hyperparameters
        merged = registry_defaults | self._hyperparameter_overrides | (overrides or {})
        merged.setdefault("model_path", self._config.model_source_uri)
        return merged

    @abstractmethod
    def _build_predictor_init_args(self, **user_kwargs) -> Dict[str, Any]:
        """Build predictor_init_args dict from user-provided kwargs.

        Subclasses override to map their public API kwargs (e.g., prediction_length,
        target, known_covariates_names) to the dict that TimeSeriesPredictor/TabularPredictor expects.

        Parameters
        ----------
        **user_kwargs
            Task-specific keyword arguments taken from the subclass's public API.

        Returns
        -------
        Dict[str, Any]
            Keyword arguments for the predictor constructor.
        """
        # NOTE(review): FoundationModel does not use ABCMeta, so @abstractmethod here is documentation only;
        # this base implementation is callable and returns None.
        ...

    @abstractmethod
    def _build_predictor_fit_args(self, hyperparameters: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        """Build predictor_fit_args dict. Subclasses override with task-specific logic.

        Parameters
        ----------
        hyperparameters
            Optional call-site hyperparameter overrides.

        Returns
        -------
        Dict[str, Any]
            Keyword arguments for the predictor's ``fit`` call.
        """
        ...

    @property
    @abstractmethod
    def _serve_script_path(self) -> str:
        """Path to the serve script for this model type.

        Used as the model ``entry_point`` in :meth:`_deploy_backend` and bundled under ``code/`` by
        :meth:`cache_model_artifact`.
        """
        ...

    @abstractmethod
    def deploy(self, **kwargs):
        """Deploy model to a real-time endpoint.

        Subclasses implement this and return a task-specific endpoint
        (e.g., TimeSeriesEndpoint, TabularEndpoint).

        Parameters
        ----------
        **kwargs
            Deployment options defined by the subclass.
        """
        ...

    @abstractmethod
    def predict(self, data: Union[str, Path, pd.DataFrame], wait: bool = True, **kwargs) -> Optional[pd.DataFrame]:
        """Run prediction. Subclasses override with a task-specific signature.

        Parameters
        ----------
        data
            Input data, as a DataFrame or a path to a data file.
        wait
            Whether to block until the prediction finishes.
        **kwargs
            Task-specific options defined by the subclass.
        """
        ...

    def _deploy_backend(
        self,
        instance_type: Optional[str] = None,
        endpoint_name: Optional[str] = None,
        hyperparameters: Optional[Dict[str, Any]] = None,
        framework_version: str = "latest",
        custom_image_uri: Optional[str] = None,
        wait: bool = True,
        inference_mode: Literal["realtime", "serverless"] = "realtime",
        inference_config: Optional[Dict[str, Any]] = None,
        **backend_kwargs,
    ) -> None:
        """Shared deploy logic. Subclasses call this then wrap the endpoint."""
        if inference_mode == "serverless" and instance_type is not None:
            raise ValueError("`instance_type` must not be set when `inference_mode='serverless'`.")
        if instance_type is None and inference_mode == "realtime":
            instance_type = self._config.deploy_instance_type

        merged_hp = self._get_hyperparameters("inference", hyperparameters)
        if self.model_artifact_uri is not None:
            user_model_path = (hyperparameters or {}).get("model_path") or self._hyperparameter_overrides.get(
                "model_path"
            )
            if user_model_path is not None:
                raise ValueError(
                    "Cannot set hyperparameters['model_path'] when model_artifact_uri is in use — the bundled artifact "
                    f"determines the in-container weights path ({_CONTAINER_WEIGHTS_DIR}). Drop model_path, or call "
                    "deploy() on a FoundationModel without model_artifact_uri."
                )
            merged_hp["model_path"] = _CONTAINER_WEIGHTS_DIR
        fm_serve_config = {
            "ag_model_key": self._config.ag_model_key,
            "hyperparameters": merged_hp,
        }

        model_kwargs = backend_kwargs.pop("model_kwargs", {})
        model_kwargs["entry_point"] = self._serve_script_path

        # FM deploys never want SDK repack: predictor_path is either None (script-only tarball is built locally) or a
        # pre-bundled cache artifact that already contains the serve script.
        self._backend.deploy(
            predictor_path=self.model_artifact_uri,
            endpoint_name=endpoint_name,
            framework_version=framework_version,
            instance_type=instance_type,
            custom_image_uri=custom_image_uri,
            wait=wait,
            model_kwargs=model_kwargs,
            fm_serve_config=fm_serve_config,
            inference_mode=inference_mode,
            inference_config=inference_config,
            repack=False,
            **backend_kwargs,
        )
        assert self._backend.endpoint is not None

    def fit(
        self,
        train_data: Union[str, Path, pd.DataFrame],
        output_path: Optional[str] = None,
        instance_type: Optional[str] = None,
        hyperparameters: Optional[Dict[str, Any]] = None,
        wait: bool = True,
        **kwargs,
    ) -> "FoundationModel":
        """
        Fine-tune the model. Returns a new FoundationModel pointing to the fine-tuned artifact.

        Parameters
        ----------
        train_data
            Training data, as a DataFrame or local/S3 path to a data file.
        output_path
            S3 path to store fine-tuned model.
            If None, will auto-generate under cloud_output_path.
        instance_type
            Instance type for the training job.
            If None, will use the default from the model registry.
        hyperparameters
            Model hyperparameters for training. Overrides values passed to the constructor.
            Available hyperparameters for each model are listed in the AutoGluon documentation.
        wait
            If True, block until training completes.

        Returns
        -------
        FoundationModel
            New instance with hyperparameters pointing to the fine-tuned artifact.

        :meta private:
        """
        if not self._config.fine_tunable:
            raise ValueError(f"Model '{self.model_id}' does not support fine-tuning.")
        raise NotImplementedError

    def cache_model_artifact(self, cache_path: str, *, overwrite: bool = False) -> "FoundationModel":
        """
        Download model weights from HuggingFace, bundle them with the FM serve script into a SageMaker-compatible
        ``model.tar.gz``, and upload to S3.

        Lets :meth:`deploy` skip the runtime HuggingFace download — required for network-isolated endpoints (e.g.
        SageMaker Serverless Inference). Returns a new :class:`FoundationModel` with ``model_artifact_uri`` set to the
        uploaded tarball.

        Destination key: ``{cache_path}/{model_id}/model.tar.gz``. If it already exists, upload is skipped unless
        ``overwrite=True``; a stale-cache mismatch between the bundled artifact's autogluon-cloud version and the
        current version raises ``RuntimeError`` and prompts the caller to re-bundle.

        Parameters
        ----------
        cache_path
            S3 prefix under which the artifact will be uploaded. Multiple foundation models can share one prefix.
        overwrite
            If True, re-upload even when the destination key exists.

        Returns
        -------
        FoundationModel
            A new instance with ``model_artifact_uri`` populated. The original is unchanged.
        """
        from huggingface_hub import snapshot_download

        if not cache_path.startswith("s3://"):
            raise ValueError(f"cache_path must be an s3:// URI, got: {cache_path!r}")

        source_uri = self._config.model_source_uri
        cache_key = f"{cache_path.rstrip('/')}/{self.model_id}/model.tar.gz"
        bucket, key = s3_path_to_bucket_prefix(cache_key)
        s3 = self._backend.sagemaker_session.boto_session.client("s3")

        head = None if overwrite else _s3_head_or_none(s3, bucket, key)
        if head is not None:
            cached_version = head["Metadata"].get(_AG_CLOUD_VERSION_METADATA_KEY)
            if cached_version != __version__:
                raise RuntimeError(
                    f"Cached artifact at {cache_key} was bundled with autogluon-cloud "
                    f"{cached_version!r}, current is {__version__!r}. "
                    f"Pass overwrite=True to re-bundle and re-upload."
                )
            logger.info(f"Cached artifact already exists at {cache_key}; skipping upload")
        else:
            with tempfile.TemporaryDirectory(prefix="ag_fm_cache_") as tmp:
                tmp_path = Path(tmp)
                weights_dir = tmp_path / "weights"
                logger.info(f"Downloading {source_uri} from HuggingFace to {weights_dir}")
                snapshot_download(repo_id=source_uri, local_dir=str(weights_dir))

                # Mirror the layout produced by SagemakerBackend._create_serve_script_tarball:
                # entry-point script + serving_utils/ under code/, so the cached endpoint can
                # `from serving_utils.timeseries import ...` exactly like a fresh deploy.
                serve_script = Path(self._serve_script_path)
                tarball = tmp_path / "model.tar.gz"
                logger.info(f"Bundling weights + serve script into {tarball}")
                with tarfile.open(tarball, "w:gz") as tar:
                    tar.add(weights_dir, arcname="weights")
                    tar.add(serve_script, arcname=f"code/{serve_script.name}")
                    tar.add(ScriptManager.SAGEMAKER_SERVING_UTILS_DIR, arcname="code/serving_utils")
                logger.info(f"Uploading to {cache_key}")
                s3.upload_file(
                    str(tarball),
                    bucket,
                    key,
                    ExtraArgs={"Metadata": {_AG_CLOUD_VERSION_METADATA_KEY: __version__}},
                )

        return self.__class__(
            model_id=self.model_id,
            hyperparameters=self._hyperparameter_overrides or None,
            model_artifact_uri=cache_key,
            cloud_output_path=self.cloud_output_path,
            role=self._backend.role_arn,
        )

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the model identity. Runtime context (``role``, ``cloud_output_path``) is excluded so configs can
        be shared across users."""
        out: Dict[str, Any] = {"model_id": self.model_id}
        if self._hyperparameter_overrides:
            out["hyperparameters"] = self._hyperparameter_overrides
        if self.model_artifact_uri:
            out["model_artifact_uri"] = self.model_artifact_uri
        return out

    def to_json(self) -> str:
        """Serialize :meth:`to_dict` output as a JSON string."""
        return json.dumps(self.to_dict())

    @classmethod
    def from_dict(cls, config: Dict[str, Any], **runtime_context: Any) -> "FoundationModel":
        """Restore from :meth:`to_dict` output. Pass ``role`` / ``cloud_output_path`` as ``runtime_context``."""
        return cls(**config, **runtime_context)

    @classmethod
    def from_json(cls, s: str, **runtime_context: Any) -> "FoundationModel":
        """Restore from a :meth:`to_json` string."""
        return cls.from_dict(json.loads(s), **runtime_context)


class TimeSeriesFoundationModel(FoundationModel):
    """Foundation model for time series forecasting (Chronos, etc.)."""

    _backend_map = {SAGEMAKER: TIMESERIES_SAGEMAKER}
    _predictor_type = "timeseries"

    @property
    def _serve_script_path(self) -> str:
        """Path to the time series FM serve script."""
        return ScriptManager.SAGEMAKER_TIMESERIES_FM_SERVE_SCRIPT_PATH

    def deploy(
        self,
        instance_type: Optional[str] = None,
        endpoint_name: Optional[str] = None,
        hyperparameters: Optional[Dict[str, Any]] = None,
        framework_version: str = "latest",
        custom_image_uri: Optional[str] = None,
        wait: bool = True,
        inference_mode: Literal["realtime", "serverless"] = "realtime",
        inference_config: Optional[Dict[str, Any]] = None,
        **backend_kwargs,
    ) -> TimeSeriesEndpoint:
        """
        Deploy model to an inference endpoint.

        Parameters
        ----------
        instance_type
            Instance type for the endpoint. Defaults to the model registry value. Must be ``None`` when
            ``inference_mode="serverless"``.
        endpoint_name
            Custom endpoint name. If None, will auto-generate a unique name.
        hyperparameters
            Model hyperparameters for inference. Overrides values passed to the constructor.
        framework_version
            Container framework version. If 'latest', uses the most recent available.
        custom_image_uri
            Custom Docker image URI for the inference container.
        wait
            Whether to block until the endpoint is ready.
        inference_mode
            Endpoint type. ``"serverless"`` provisions a SageMaker Serverless Inference endpoint (no instance
            management, scales to zero).
        inference_config
            Mode-specific overrides forwarded to ``sagemaker.serverless.ServerlessInferenceConfig`` (e.g.
            ``memory_size_in_mb``, ``max_concurrency``).
        **backend_kwargs
            Backend-specific arguments (e.g., initial_instance_count, volume_size, model_kwargs, deploy_kwargs).

        Returns
        -------
        TimeSeriesEndpoint
            Wrapper around the endpoint created by the backend.
        """
        self._deploy_backend(
            instance_type=instance_type,
            endpoint_name=endpoint_name,
            hyperparameters=hyperparameters,
            framework_version=framework_version,
            custom_image_uri=custom_image_uri,
            wait=wait,
            inference_mode=inference_mode,
            inference_config=inference_config,
            **backend_kwargs,
        )
        return TimeSeriesEndpoint(self._backend.endpoint)

    def _build_predictor_fit_args(self, hyperparameters: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        """Build TimeSeriesPredictor fit args that configure only this foundation model, with model selection
        skipped."""
        merged_hp = self._get_hyperparameters("inference", hyperparameters)
        return {
            "hyperparameters": {self._config.ag_model_key: merged_hp},
            "skip_model_selection": True,
        }

    def _build_predictor_init_args(
        self,
        target: str = "target",
        prediction_length: int = 1,
        quantile_levels: Optional[List[float]] = None,
        **kwargs,
    ) -> Dict[str, Any]:
        """Map user kwargs to TimeSeriesPredictor init args."""
        args: Dict[str, Any] = {
            "target": target,
            "prediction_length": prediction_length,
        }
        # Only pass quantile_levels when given, so the predictor's own default applies otherwise.
        if quantile_levels is not None:
            args["quantile_levels"] = quantile_levels
        return args

    def predict(
        self,
        data: Union[str, Path, pd.DataFrame],
        target: str = "target",
        id_column: str = "item_id",
        timestamp_column: str = "timestamp",
        known_covariates: Optional[Union[str, Path, pd.DataFrame]] = None,
        static_features: Optional[Union[str, Path, pd.DataFrame]] = None,
        prediction_length: int = 1,
        quantile_levels: Optional[List[float]] = None,
        hyperparameters: Optional[Dict[str, Any]] = None,
        instance_type: Optional[str] = None,
        framework_version: str = "latest",
        custom_image_uri: Optional[str] = None,
        wait: bool = True,
        predictions_path: Optional[str] = None,
        **backend_kwargs,
    ) -> Union[pd.DataFrame, JobPredictionFuture]:
        """
        Run batch prediction for time series.

        Parameters
        ----------
        data
            Historical time series to forecast from, in long format, as a DataFrame or local/S3 path to a data file.
            See the `TimeSeriesPredictor docs
            <https://auto.gluon.ai/stable/api/autogluon.timeseries.TimeSeriesPredictor.html>`_ for the expected
            format.
        target
            Name of the column that contains the target values to forecast.
        id_column
            Name of the column with the unique identifier of each time series (item).
        timestamp_column
            Name of the column with the observation timestamps.
        known_covariates
            Future values of the known covariates over the forecast horizon. Covariate column names are inferred from
            the columns (excluding ``id_column`` and ``timestamp_column``).
        static_features
            Static (time-independent) features describing each individual time series.
        prediction_length
            Forecast horizon: how many time steps into the future the model should predict.
        quantile_levels
            List of increasing decimals between 0 and 1 specifying which quantiles to estimate. Defaults to
            ``[0.1, 0.2, ..., 0.9]``.
        hyperparameters
            Model hyperparameters for inference. Overrides values passed to the constructor.
        instance_type
            Instance type for the prediction job. If None, uses registry default.
        framework_version
            Container framework version.
        custom_image_uri
            Custom Docker image URI for the container.
        wait
            If True, block and return a DataFrame. If False, return a :class:`JobPredictionFuture` immediately — call
            ``.result()`` on it later to retrieve the DataFrame, or ``.status()`` to check progress.
        predictions_path
            S3 URL where predictions will be written by the prediction job (e.g.
            ``s3://my-bucket/runs/2024-05-01/predictions.csv``). The container's SageMaker execution role must have
            ``s3:PutObject`` permission for this location. Defaults to
            ``{cloud_output_path}/{job_name}/predictions.csv``. Predictions use AutoGluon's canonical column names
            ``item_id`` and ``timestamp``, regardless of the ``id_column`` / ``timestamp_column`` passed in.
        **backend_kwargs
            Additional backend-specific arguments (e.g., job_name, volume_size,
            autogluon_sagemaker_estimator_kwargs).

        Returns
        -------
        pd.DataFrame or JobPredictionFuture
            DataFrame if ``wait=True``; a :class:`JobPredictionFuture` otherwise.
        """
        if instance_type is None:
            instance_type = self._config.predict_instance_type

        predictor_init_args = self._build_predictor_init_args(
            target=target,
            prediction_length=prediction_length,
            quantile_levels=quantile_levels,
        )
        predictor_fit_args = self._build_predictor_fit_args(hyperparameters)

        # The history goes in as the "train_data" channel: prediction runs as a fit job with predict_after_fit set.
        data_channels = {
            "train_data": data,
            "known_covariates": known_covariates,
            "static_features": static_features,
        }

        extra_ag_args: Dict[str, Any] = {"predict_after_fit": True}
        if predictions_path is not None:
            extra_ag_args["predictions_path"] = predictions_path

        self._backend.fit(
            predictor_init_args=predictor_init_args,
            predictor_fit_args=predictor_fit_args,
            data_channels=data_channels,
            id_column=id_column,
            timestamp_column=timestamp_column,
            framework_version=framework_version,
            instance_type=instance_type,
            custom_image_uri=custom_image_uri,
            wait=wait,
            extra_ag_args=extra_ag_args,
            **backend_kwargs,
        )

        if not wait:
            return JobPredictionFuture(
                job=self._backend._fit_job,
                result_loader=self._backend.get_fit_predict_results,
            )
        return self._backend.get_fit_predict_results()


class TabularFoundationModel(FoundationModel):
    """Foundation model for tabular prediction (Mitra, TabICL, etc.)."""

    _backend_map = {SAGEMAKER: TABULAR_SAGEMAKER}
    _predictor_type = "tabular"

    @property
    def _serve_script_path(self) -> str:
        """Not available: tabular FM deploy is not yet supported, so there is no serve script."""
        raise NotImplementedError("Tabular FM deploy is not yet supported")

    def deploy(self, **kwargs):
        """Not yet supported for tabular foundation models; always raises ``NotImplementedError``."""
        raise NotImplementedError("Tabular FM deploy is not yet supported")

    def _build_predictor_init_args(self, label: str = "target", **kwargs) -> Dict[str, Any]:
        """Map user kwargs to TabularPredictor init args."""
        return {"label": label}

    def predict(
        self,
        train_data: Union[str, Path, pd.DataFrame],
        test_data: Union[str, Path, pd.DataFrame],
        label: str = "target",
        hyperparameters: Optional[Dict[str, Any]] = None,
        instance_type: Optional[str] = None,
        framework_version: str = "latest",
        custom_image_uri: Optional[str] = None,
        wait: bool = True,
        **backend_kwargs,
    ) -> Optional[pd.DataFrame]:
        """
        Run batch prediction for tabular tasks.

        For tabular foundation models (e.g., Mitra), train_data provides the few-shot context and test_data contains
        the rows to predict on.

        Not implemented yet: currently always raises ``NotImplementedError``.

        Parameters
        ----------
        train_data
            Labeled few-shot context for the foundation model.
        test_data
            Unlabeled data to predict on.
        label
            Target column name in train_data.
        hyperparameters
            Model hyperparameters for inference. Overrides values passed to the constructor.
        instance_type
            Instance type for the prediction job. If None, uses registry default.
        framework_version
            Container framework version.
        custom_image_uri
            Custom Docker image URI for the container.
        wait
            If True, block and return DataFrame. If False, return the job handle.
        **backend_kwargs
            Additional backend-specific arguments.

        Returns
        -------
        Optional[pd.DataFrame]
        """
        # TODO: requires fit_predict support for TabularCloudPredictor
        raise NotImplementedError

    def predict_proba(
        self,
        train_data: Union[str, Path, pd.DataFrame],
        test_data: Union[str, Path, pd.DataFrame],
        label: str = "target",
        hyperparameters: Optional[Dict[str, Any]] = None,
        output_path: Optional[str] = None,
        instance_type: Optional[str] = None,
        wait: bool = True,
        **backend_kwargs,
    ) -> Optional[pd.DataFrame]:
        """
        Run batch prediction returning class probabilities.

        Not implemented yet: currently always raises ``NotImplementedError``.

        Parameters
        ----------
        train_data
            Labeled few-shot context for the foundation model.
        test_data
            Unlabeled data to predict on.
        label
            Target column name in train_data.
        hyperparameters
            Model hyperparameters for inference. Overrides values passed to the constructor.
            Available hyperparameters for each model are listed in the AutoGluon documentation.
        output_path
            S3 path to store predictions.
            If None, will auto-generate under cloud_output_path.
        instance_type
            Instance type for the prediction job.
            If None, will use the default from the model registry.
        wait
            If True, block and return DataFrame. If False, return the job handle.
        **backend_kwargs
            Additional backend-specific arguments (e.g. job_name, custom_image_uri, framework_version, volume_size).

        Returns
        -------
        Optional[pd.DataFrame]
        """
        raise NotImplementedError