fit_predict

TabularCloudPredictor.fit_predict(train_data: str | Path | DataFrame, test_data: str | Path | DataFrame, *, predictor_init_args: Dict[str, Any], predictor_fit_args: Dict[str, Any] | None = None, framework_version: str = 'latest', job_name: str | None = None, instance_type: str = 'ml.m5.2xlarge', instance_count: int = 1, volume_size: int = 256, custom_image_uri: str | None = None, timeout: int = 86400, wait: bool = True, predictions_path: str | None = None, backend_kwargs: Dict | None = None) → DataFrame | None

Fit and predict in a single SageMaker training job.

This is useful for tabular foundation-model workflows (e.g. Mitra) where “fit” is essentially loading a pretrained model. Running fit and predict in the same job avoids paying the SageMaker startup overhead twice.

For classification tasks, the returned DataFrame matches the output of TabularCloudPredictor.predict_proba() with include_predict=True: the first column is the predicted class and the remaining columns are class probabilities (suffixed _proba). Use autogluon.cloud.utils.utils.split_pred_and_pred_proba() to separate the predictions from the probabilities.
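The column layout described above can be illustrated with plain pandas. This is only a minimal sketch of a positional split, not the implementation of split_pred_and_pred_proba(). The toy column names (class, yes_proba, no_proba) and the helper split_by_position are hypothetical and exist only for illustration.

```python
import pandas as pd

# Toy stand-in for a classification result (hypothetical column names):
# first column = predicted class, remaining columns = class probabilities.
result = pd.DataFrame(
    {
        "class": ["yes", "no"],
        "yes_proba": [0.9, 0.2],
        "no_proba": [0.1, 0.8],
    }
)


def split_by_position(df: pd.DataFrame):
    """Split a combined result into (predictions, probabilities) by column position."""
    pred = df.iloc[:, 0]    # first column: predicted class
    proba = df.iloc[:, 1:]  # remaining columns: class probabilities
    return pred, proba


pred, pred_proba = split_by_position(result)
```

In real code, prefer the library helper named above. The positional split only shows what the two pieces look like.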

Parameters:
  • train_data (Union[str, pathlib.Path, pd.DataFrame]) – Labeled training data, as a DataFrame or local/S3 path to a data file.

  • test_data (Union[str, pathlib.Path, pd.DataFrame]) – Unlabeled data to predict on, as a DataFrame or local/S3 path to a data file.

  • predictor_init_args (dict) – Arguments forwarded to TabularPredictor() (must include label).

  • predictor_fit_args (Optional[dict], default = None) – Additional fit args forwarded to TabularPredictor.fit().

  • predictions_path (Optional[str]) – S3 URL where predictions will be written by the training container. Defaults to {cloud_output_path}/{job_name}/predictions.csv. Must end in .csv or .parquet.

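  • framework_version (str, default = 'latest') – Same semantics as fit().

  • job_name (Optional[str], default = None) – Same semantics as fit().

  • instance_type (str, default = 'ml.m5.2xlarge') – Same semantics as fit().

  • instance_count (int, default = 1) – Same semantics as fit().

  • volume_size (int, default = 256) – Same semantics as fit().

  • custom_image_uri (Optional[str], default = None) – Same semantics as fit().

  • timeout (int, default = 86400) – Same semantics as fit().

  • wait (bool, default = True) – Same semantics as fit().

  • backend_kwargs (Optional[dict], default = None) – Same semantics as fit().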
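The predictions_path rule in the parameter list (a default of {cloud_output_path}/{job_name}/predictions.csv, and a required .csv or .parquet suffix) can be sketched as a small pre-flight check. The helper resolve_predictions_path below is hypothetical and not part of autogluon.cloud. The library's own validation and path handling may differ, and stripping the trailing slash is an assumption made here.

```python
from typing import Optional

# Suffixes the documentation says predictions_path must end in.
ALLOWED_SUFFIXES = (".csv", ".parquet")


def resolve_predictions_path(
    cloud_output_path: str,
    job_name: str,
    predictions_path: Optional[str] = None,
) -> str:
    """Return the path predictions would be written to (hypothetical helper)."""
    if predictions_path is None:
        # Documented default: {cloud_output_path}/{job_name}/predictions.csv
        return f"{cloud_output_path.rstrip('/')}/{job_name}/predictions.csv"
    if not predictions_path.endswith(ALLOWED_SUFFIXES):
        raise ValueError("predictions_path must end in .csv or .parquet")
    return predictions_path
```

Checking the suffix locally like this can catch a bad path before a training job is launched.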

Returns:

Predictions as a DataFrame. Returns None when wait is False.

Return type:

Optional[pd.DataFrame]
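A call might look like the following sketch. It assumes autogluon.cloud is installed, AWS credentials are configured, and TabularCloudPredictor accepts a cloud_output_path argument. The helper names build_fit_predict_kwargs and run_fit_predict are hypothetical, and the paths and label are supplied by the caller.

```python
from typing import Any, Dict


def build_fit_predict_kwargs(label: str, wait: bool = True) -> Dict[str, Any]:
    """Collect keyword arguments for fit_predict (names from the signature above)."""
    return {
        "predictor_init_args": {"label": label},  # must include the label
        "instance_type": "ml.m5.2xlarge",         # documented default
        "wait": wait,
    }


def run_fit_predict(train_path: str, test_path: str, cloud_output_path: str, label: str):
    """Launch one SageMaker job that fits and predicts (needs AWS access)."""
    # Imported lazily so this module loads without autogluon.cloud installed.
    from autogluon.cloud import TabularCloudPredictor

    cloud_predictor = TabularCloudPredictor(cloud_output_path=cloud_output_path)
    predictions = cloud_predictor.fit_predict(
        train_path, test_path, **build_fit_predict_kwargs(label)
    )
    # With wait=True this is a DataFrame; with wait=False it would be None.
    return predictions
```

Because wait=False makes the call return None, callers that launch the job asynchronously should check for None before using the result.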