fit_predict
- TabularCloudPredictor.fit_predict(train_data: str | Path | DataFrame, test_data: str | Path | DataFrame, *, predictor_init_args: Dict[str, Any], predictor_fit_args: Dict[str, Any] | None = None, framework_version: str = 'latest', job_name: str | None = None, instance_type: str = 'ml.m5.2xlarge', instance_count: int = 1, volume_size: int = 256, custom_image_uri: str | None = None, timeout: int = 86400, wait: bool = True, predictions_path: str | None = None, backend_kwargs: Dict | None = None) → DataFrame | None
Fit and predict in a single SageMaker training job.
This is useful for tabular foundation-model workflows (e.g. Mitra), where “fit” essentially amounts to loading a pretrained model. Running fit and predict in the same job means the SageMaker job startup overhead is paid once instead of twice.
For classification tasks, the returned DataFrame matches the output of
TabularCloudPredictor.predict_proba() with include_predict=True: the first column is the predicted class and the remaining columns are class probabilities (suffixed with _proba). Use autogluon.cloud.utils.utils.split_pred_and_pred_proba() to split them into predictions and probabilities.
- Parameters:
train_data (Union[str, pathlib.Path, pd.DataFrame]) – Labeled training data, as a DataFrame or local/S3 path to a data file.
test_data (Union[str, pathlib.Path, pd.DataFrame]) – Unlabeled data to predict on, as a DataFrame or local/S3 path to a data file.
predictor_init_args (dict) – Arguments forwarded to TabularPredictor() (must include label).
predictor_fit_args (Optional[dict], default = None) – Additional fit args forwarded to TabularPredictor.fit().
predictions_path (Optional[str]) – S3 URL where the training container writes the predictions. Defaults to {cloud_output_path}/{job_name}/predictions.csv. Must end in .csv or .parquet.
framework_version, job_name, instance_type, instance_count, volume_size, custom_image_uri, timeout, wait, backend_kwargs – Same semantics as the corresponding arguments of fit().
- Returns:
Predictions as a DataFrame. Returns None when wait is False.
- Return type:
Optional[pd.DataFrame]
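For classification, the layout described above (predicted class first, then probability columns suffixed with _proba) can be split with plain pandas. The sketch below is a hypothetical stand-in for autogluon.cloud.utils.utils.split_pred_and_pred_proba(), assuming exactly that column layout; the library helper's exact return types may differ, and the column names in the sample frame are made up for illustration.

```python
import pandas as pd


def split_pred_and_proba(result: pd.DataFrame) -> tuple[pd.Series, pd.DataFrame]:
    # Hypothetical stand-in, not the library's split_pred_and_pred_proba().
    # Assumes column 0 is the predicted class and every remaining column is a
    # class probability whose name ends in "_proba".
    pred = result.iloc[:, 0]
    # Strip the "_proba" suffix so probability columns are keyed by class label.
    proba = result.iloc[:, 1:].rename(
        columns=lambda c: c.removesuffix("_proba") if isinstance(c, str) else c
    )
    return pred, proba


# Illustrative frame; real column names depend on your label and classes.
result = pd.DataFrame(
    {
        "class": ["a", "b"],
        "a_proba": [0.9, 0.2],
        "b_proba": [0.1, 0.8],
    }
)
pred, proba = split_pred_and_proba(result)
print(pred.tolist())        # ['a', 'b']
print(list(proba.columns))  # ['a', 'b']
```

Keying the probability columns by bare class label makes it easy to compare them against pred, for example with proba.idxmax(axis=1).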
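The two documented rules for predictions_path (the default location and the required .csv or .parquet suffix) can be captured in a small validation helper. resolve_predictions_path below is hypothetical and not part of AutoGluon Cloud; stripping a trailing slash from cloud_output_path is an assumption made only to avoid a double slash in the default.

```python
from typing import Optional


def resolve_predictions_path(
    cloud_output_path: str, job_name: str, predictions_path: Optional[str] = None
) -> str:
    """Hypothetical helper applying the documented predictions_path rules."""
    if predictions_path is None:
        # Documented default: {cloud_output_path}/{job_name}/predictions.csv.
        # Stripping a trailing "/" is an assumption, not documented behaviour.
        predictions_path = f"{cloud_output_path.rstrip('/')}/{job_name}/predictions.csv"
    # Documented constraint: the path must end in .csv or .parquet.
    if not predictions_path.endswith((".csv", ".parquet")):
        raise ValueError(
            f"predictions_path must end in .csv or .parquet, got {predictions_path!r}"
        )
    return predictions_path


print(resolve_predictions_path("s3://my-bucket/ag-output", "job-1"))
# s3://my-bucket/ag-output/job-1/predictions.csv
print(resolve_predictions_path("s3://my-bucket/ag-output", "job-1", "s3://my-bucket/preds.parquet"))
# s3://my-bucket/preds.parquet
```

Once the job has finished, a .csv result can be read with pandas.read_csv and a .parquet result with pandas.read_parquet; reading s3:// URLs directly from pandas requires the s3fs package.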