Skip to content

Commit cd406fa

Browse files
rsareddy0329 and Roja Reddy Sareddy authored
Fix: Bug fixes for s3 path validation, mlflow app creation (#5402)
* fix: Fix the recipe selection for multiple recipe scenario * fix: Fix the recipe selection for multiple recipe scenario * fix: Hyperparameter issue fixes, validate s3 output path, additional unit tests * Fix: Add validation to bedrock reward models * Fix: Add validation to bedrock reward models * Fix: Add allow list for bedrock eval models * Fix: Add allow list for bedrock eval models * Fix: Bug fixes for s3 path validation, mlflow app creation * Fix: Update Legal verbiage, and allowed reward model ids based on region * Fix: Update model_package_group_name to model_package_group in all trainers to maintain consistency --------- Co-authored-by: Roja Reddy Sareddy <[email protected]>
1 parent 5e7a3ef commit cd406fa

21 files changed

+695
-813
lines changed

sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py

Lines changed: 34 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -198,10 +198,19 @@ def _create_mlflow_app(sagemaker_session) -> Optional[MlflowApp]:
198198
if new_app.status in ["Created", "Updated"]:
199199
return new_app
200200
elif new_app.status in ["Failed", "Stopped"]:
201-
raise RuntimeError(f"MLflow app creation failed with status: {new_app.status}")
201+
# Get detailed error from MLflow app
202+
error_msg = f"MLflow app creation failed with status: {new_app.status}"
203+
if hasattr(new_app, 'failure_reason') and new_app.failure_reason:
204+
error_msg += f". Reason: {new_app.failure_reason}"
205+
raise RuntimeError(error_msg)
202206
time.sleep(poll_interval)
203207

204-
raise RuntimeError(f"MLflow app creation timed out after {max_wait_time} seconds")
208+
# Timeout case - get current status and any error details
209+
new_app.refresh()
210+
error_msg = f"MLflow app creation failed. Current status: {new_app.status}"
211+
if hasattr(new_app, 'failure_reason') and new_app.failure_reason:
212+
error_msg += f". Reason: {new_app.failure_reason}"
213+
raise RuntimeError(error_msg)
205214

206215
except Exception as e:
207216
logger.error("Failed to create MLflow app: %s", e)
@@ -693,7 +702,7 @@ def _validate_eula_for_gated_model(model, accept_eula, is_gated_model):
693702

694703

695704
def _validate_s3_path_exists(s3_path: str, sagemaker_session):
696-
"""Validate if S3 path exists and is accessible."""
705+
"""Validate S3 path and create bucket/prefix if they don't exist."""
697706
if not s3_path.startswith("s3://"):
698707
raise ValueError(f"Invalid S3 path format: {s3_path}")
699708

@@ -705,19 +714,34 @@ def _validate_s3_path_exists(s3_path: str, sagemaker_session):
705714
s3_client = sagemaker_session.boto_session.client('s3')
706715

707716
try:
708-
# Check if bucket exists and is accessible
709-
s3_client.head_bucket(Bucket=bucket_name)
717+
# Check if bucket exists, create if it doesn't
718+
try:
719+
s3_client.head_bucket(Bucket=bucket_name)
720+
except Exception as e:
721+
if "NoSuchBucket" in str(e) or "Not Found" in str(e):
722+
# Create bucket
723+
region = sagemaker_session.boto_region_name
724+
if region == 'us-east-1':
725+
s3_client.create_bucket(Bucket=bucket_name)
726+
else:
727+
s3_client.create_bucket(
728+
Bucket=bucket_name,
729+
CreateBucketConfiguration={'LocationConstraint': region}
730+
)
731+
else:
732+
raise
710733

711-
# If prefix is provided, check if it exists
734+
# If prefix is provided, check if it exists, create if it doesn't
712735
if prefix:
713736
response = s3_client.list_objects_v2(Bucket=bucket_name, Prefix=prefix, MaxKeys=1)
714737
if 'Contents' not in response:
715-
raise ValueError(f"S3 prefix '{prefix}' does not exist in bucket '{bucket_name}'")
738+
# Create the prefix by putting an empty object
739+
if not prefix.endswith('/'):
740+
prefix += '/'
741+
s3_client.put_object(Bucket=bucket_name, Key=prefix, Body=b'')
716742

717743
except Exception as e:
718-
if "NoSuchBucket" in str(e):
719-
raise ValueError(f"S3 bucket '{bucket_name}' does not exist or is not accessible")
720-
raise ValueError(f"Failed to validate S3 path '{s3_path}': {str(e)}")
744+
raise ValueError(f"Failed to validate/create S3 path '{s3_path}': {str(e)}")
721745

722746

723747
def _validate_hyperparameter_values(hyperparameters: dict):

sagemaker-train/src/sagemaker/train/constants.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -42,13 +42,15 @@
4242

4343
HUB_NAME = "SageMakerPublicHub"
4444

45-
# Allowed reward model IDs for RLAIF trainer
46-
_ALLOWED_REWARD_MODEL_IDS = [
47-
"openai.gpt-oss-120b-1:0",
48-
"openai.gpt-oss-20b-1:0",
49-
"qwen.qwen3-32b-v1:0",
50-
"qwen.qwen3-coder-30b-a3b-v1:0"
51-
]
45+
# Allowed reward model IDs for RLAIF trainer with region restrictions
46+
_ALLOWED_REWARD_MODEL_IDS = {
47+
"openai.gpt-oss-120b-1:0": ["us-west-2", "us-east-1", "ap-northeast-1", "eu-west-1"],
48+
"openai.gpt-oss-20b-1:0": ["us-west-2", "us-east-1", "ap-northeast-1", "eu-west-1"],
49+
"qwen.qwen3-32b-v1:0": ["us-west-2", "us-east-1", "ap-northeast-1", "eu-west-1"],
50+
"qwen.qwen3-coder-30b-a3b-v1:0": ["us-west-2", "us-east-1", "ap-northeast-1", "eu-west-1"],
51+
"qwen.qwen3-coder-480b-a35b-v1:0": ["us-west-2", "ap-northeast-1"],
52+
"qwen.qwen3-235b-a22b-2507-v1:0": ["us-west-2", "ap-northeast-1"]
53+
}
5254

5355
# Allowed evaluator models for LLM as Judge evaluator with region restrictions
5456
_ALLOWED_EVALUATOR_MODELS = {

sagemaker-train/src/sagemaker/train/dpo_trainer.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ class DPOTrainer(BaseTrainer):
4141
trainer = DPOTrainer(
4242
model="meta-llama/Llama-2-7b-hf",
4343
training_type=TrainingType.LORA,
44-
model_package_group_name="my-model-group",
44+
model_package_group="my-model-group",
4545
training_dataset="s3://bucket/preference_data.jsonl"
4646
)
4747
@@ -50,7 +50,7 @@ class DPOTrainer(BaseTrainer):
5050
# Complete workflow: create -> wait -> get model package ARN
5151
trainer = DPOTrainer(
5252
model="meta-llama/Llama-2-7b-hf",
53-
model_package_group_name="my-dpo-models"
53+
model_package_group="my-dpo-models"
5454
)
5555
5656
# Create training job (non-blocking)
@@ -75,7 +75,7 @@ class DPOTrainer(BaseTrainer):
7575
training_type (Union[TrainingType, str]):
7676
The fine-tuning approach. Valid values are TrainingType.LORA (default),
7777
TrainingType.FULL.
78-
model_package_group_name (Optional[Union[str, ModelPackageGroup]]):
78+
model_package_group (Optional[Union[str, ModelPackageGroup]]):
7979
The model package group for storing the fine-tuned model. Can be a group name,
8080
ARN, or ModelPackageGroup object. Required when model is not a ModelPackage.
8181
mlflow_resource_arn (Optional[str]):
@@ -86,9 +86,9 @@ class DPOTrainer(BaseTrainer):
8686
mlflow_run_name (Optional[str]):
8787
The MLflow run name for this training job.
8888
training_dataset (Optional[Union[str, DataSet]]):
89-
The training dataset with preference pairs. Can be an S3 URI, dataset ARN, or DataSet object.
89+
The training dataset with preference pairs. Can be a dataset ARN, or DataSet object.
9090
validation_dataset (Optional[Union[str, DataSet]]):
91-
The validation dataset. Can be an S3 URI, dataset ARN, or DataSet object.
91+
The validation dataset. Can be a dataset ARN, or DataSet object.
9292
s3_output_path (Optional[str]):
9393
The S3 path for training job outputs.
9494
If not specified, defaults to s3://sagemaker-<region>-<account>/output.
@@ -101,7 +101,7 @@ def __init__(
101101
self,
102102
model: Union[str, ModelPackage],
103103
training_type: Union[TrainingType, str] = TrainingType.LORA,
104-
model_package_group_name: Optional[Union[str, ModelPackageGroup]] = None,
104+
model_package_group: Optional[Union[str, ModelPackageGroup]] = None,
105105
mlflow_resource_arn: Optional[str] = None,
106106
mlflow_experiment_name: Optional[str] = None,
107107
mlflow_run_name: Optional[str] = None,
@@ -118,8 +118,8 @@ def __init__(
118118
self.model, self._model_name = _resolve_model_and_name(model, self.sagemaker_session)
119119
self.training_type = training_type
120120

121-
self.model_package_group_name = _validate_and_resolve_model_package_group(model,
122-
model_package_group_name)
121+
self.model_package_group = _validate_and_resolve_model_package_group(model,
122+
model_package_group)
123123
self.mlflow_resource_arn = mlflow_resource_arn
124124
self.mlflow_experiment_name = mlflow_experiment_name
125125
self.mlflow_run_name = mlflow_run_name
@@ -232,7 +232,7 @@ def train(self,
232232
_validate_hyperparameter_values(final_hyperparameters)
233233

234234
model_package_config = _create_model_package_config(
235-
model_package_group_name=self.model_package_group_name,
235+
model_package_group_name=self.model_package_group,
236236
model=self.model,
237237
sagemaker_session=sagemaker_session
238238
)

sagemaker-train/src/sagemaker/train/evaluate/llm_as_judge_evaluator.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,14 @@ class LLMAsJudgeEvaluator(BaseEvaluator):
2323
2424
This evaluator uses foundation models to evaluate LLM responses
2525
based on various quality and responsible AI metrics.
26-
26+
27+
This feature is powered by Amazon Bedrock Evaluations. Your use of this feature is subject to pricing of
28+
Amazon Bedrock Evaluations, the Service Terms applicable to Amazon Bedrock, and the terms that apply to your
29+
usage of third-party models. Amazon Bedrock Evaluations may securely transmit data across AWS Regions within your
30+
geography for processing. For more information, access Amazon Bedrock Evaluations documentation.
31+
32+
Documentation: https://docs.aws.amazon.com/bedrock/latest/userguide/evaluation-judge.html
33+
2734
Attributes:
2835
evaluator_model (str): AWS Bedrock foundation model identifier to use as the judge.
2936
Required. For supported models, see:

sagemaker-train/src/sagemaker/train/rlaif_trainer.py

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ class RLAIFTrainer(BaseTrainer):
4444
trainer = RLAIFTrainer(
4545
model="meta-llama/Llama-2-7b-hf",
4646
training_type=TrainingType.LORA,
47-
model_package_group_name="my-model-group",
47+
model_package_group="my-model-group",
4848
reward_model_id="reward-model-id",
4949
reward_prompt="Rate the helpfulness of this response on a scale of 1-10",
5050
training_dataset="s3://bucket/rlaif_data.jsonl"
@@ -55,7 +55,7 @@ class RLAIFTrainer(BaseTrainer):
5555
# Complete workflow: create -> wait -> get model package ARN
5656
trainer = RLAIFTrainer(
5757
model="meta-llama/Llama-2-7b-hf",
58-
model_package_group_name="my-rlaif-models",
58+
model_package_group="my-rlaif-models",
5959
reward_model_id="reward-model-id",
6060
reward_prompt="Rate the helpfulness of this response on a scale of 1-10"
6161
)
@@ -82,7 +82,7 @@ class RLAIFTrainer(BaseTrainer):
8282
training_type (Union[TrainingType, str]):
8383
The fine-tuning approach. Valid values are TrainingType.LORA (default),
8484
TrainingType.FULL.
85-
model_package_group_name (Optional[Union[str, ModelPackageGroup]]):
85+
model_package_group (Optional[Union[str, ModelPackageGroup]]):
8686
The model package group for storing the fine-tuned model. Can be a group name,
8787
ARN, or ModelPackageGroup object. Required when model is not a ModelPackage.
8888
reward_model_id (str):
@@ -100,9 +100,9 @@ class RLAIFTrainer(BaseTrainer):
100100
mlflow_run_name (Optional[str]):
101101
The MLflow run name for this training job.
102102
training_dataset (Optional[Union[str, DataSet]]):
103-
The training dataset. Can be an S3 URI, dataset ARN, or DataSet object.
103+
The training dataset. Can be a dataset ARN, or DataSet object.
104104
validation_dataset (Optional[Union[str, DataSet]]):
105-
The validation dataset. Can be an S3 URI, dataset ARN, or DataSet object.
105+
The validation dataset. Can be a dataset ARN, or DataSet object.
106106
s3_output_path (Optional[str]):
107107
The S3 path for training job outputs.
108108
If not specified, defaults to s3://sagemaker-<region>-<account>/output.
@@ -116,7 +116,7 @@ def __init__(
116116
self,
117117
model: Union[str, ModelPackage],
118118
training_type: Union[TrainingType, str] = TrainingType.LORA,
119-
model_package_group_name: Optional[Union[str, ModelPackageGroup]] = None,
119+
model_package_group: Optional[Union[str, ModelPackageGroup]] = None,
120120
reward_model_id: str = None,
121121
reward_prompt: Union[str, Evaluator] = None,
122122
mlflow_resource_arn: Optional[Union[str, MlflowTrackingServer]] = None,
@@ -138,8 +138,8 @@ def __init__(
138138
self.model, self._model_name = _resolve_model_and_name(model, self.sagemaker_session)
139139

140140
self.training_type = training_type
141-
self.model_package_group_name = _validate_and_resolve_model_package_group(model,
142-
model_package_group_name)
141+
self.model_package_group = _validate_and_resolve_model_package_group(model,
142+
model_package_group)
143143
self.reward_model_id = self._validate_reward_model_id(reward_model_id)
144144
self.reward_prompt = reward_prompt
145145
self.mlflow_resource_arn = mlflow_resource_arn
@@ -173,8 +173,20 @@ def _validate_reward_model_id(self, reward_model_id):
173173
if reward_model_id not in _ALLOWED_REWARD_MODEL_IDS:
174174
raise ValueError(
175175
f"Invalid reward_model_id '{reward_model_id}'. "
176-
f"Available models are: {_ALLOWED_REWARD_MODEL_IDS}"
176+
f"Available models are: {list(_ALLOWED_REWARD_MODEL_IDS.keys())}"
177177
)
178+
179+
# Check region compatibility
180+
session = self.sagemaker_session if hasattr(self, 'sagemaker_session') and self.sagemaker_session else TrainDefaults.get_sagemaker_session()
181+
current_region = session.boto_region_name
182+
allowed_regions = _ALLOWED_REWARD_MODEL_IDS[reward_model_id]
183+
184+
if current_region not in allowed_regions:
185+
raise ValueError(
186+
f"Reward model '{reward_model_id}' is not available in region '{current_region}'. "
187+
f"Available regions for this model: {allowed_regions}"
188+
)
189+
178190
return reward_model_id
179191

180192

@@ -239,7 +251,7 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validati
239251
_validate_hyperparameter_values(final_hyperparameters)
240252

241253
model_package_config = _create_model_package_config(
242-
model_package_group_name=self.model_package_group_name,
254+
model_package_group_name=self.model_package_group,
243255
model=self.model,
244256
sagemaker_session=sagemaker_session
245257
)

sagemaker-train/src/sagemaker/train/rlvr_trainer.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ class RLVRTrainer(BaseTrainer):
4242
trainer = RLVRTrainer(
4343
model="meta-llama/Llama-2-7b-hf",
4444
training_type=TrainingType.LORA,
45-
model_package_group_name="my-model-group",
45+
model_package_group="my-model-group",
4646
custom_reward_function="arn:aws:sagemaker:us-east-1:123456789012:hub-content/SageMakerPublicHub/JsonDoc/my-evaluator/1.0",
4747
training_dataset="s3://bucket/rlvr_data.jsonl"
4848
)
@@ -52,7 +52,7 @@ class RLVRTrainer(BaseTrainer):
5252
# Complete workflow: create -> wait -> get model package ARN
5353
trainer = RLVRTrainer(
5454
model="meta-llama/Llama-2-7b-hf",
55-
model_package_group_name="my-rlvr-models",
55+
model_package_group="my-rlvr-models",
5656
custom_reward_function="arn:aws:sagemaker:us-east-1:123456789012:hub-content/SageMakerPublicHub/JsonDoc/my-evaluator/1.0"
5757
)
5858
@@ -78,7 +78,7 @@ class RLVRTrainer(BaseTrainer):
7878
training_type (Union[TrainingType, str]):
7979
The fine-tuning approach. Valid values are TrainingType.LORA (default),
8080
TrainingType.FULL.
81-
model_package_group_name (Optional[Union[str, ModelPackageGroup]]):
81+
model_package_group (Optional[Union[str, ModelPackageGroup]]):
8282
The model package group for storing the fine-tuned model. Can be a group name,
8383
ARN, or ModelPackageGroup object. Required when model is not a ModelPackage.
8484
custom_reward_function (Optional[Union[str, Evaluator]]):
@@ -92,9 +92,9 @@ class RLVRTrainer(BaseTrainer):
9292
mlflow_run_name (Optional[str]):
9393
The MLflow run name for this training job.
9494
training_dataset (Optional[Union[str, DataSet]]):
95-
The training dataset. Can be an S3 URI, dataset ARN, or DataSet object.
95+
The training dataset. Can be a dataset ARN, or DataSet object.
9696
validation_dataset (Optional[Union[str, DataSet]]):
97-
The validation dataset. Can be an S3 URI, dataset ARN, or DataSet object.
97+
The validation dataset. Can be a dataset ARN, or DataSet object.
9898
s3_output_path (Optional[str]):
9999
The S3 path for training job outputs.
100100
If not specified, defaults to s3://sagemaker-<region>-<account>/output.
@@ -108,7 +108,7 @@ def __init__(
108108
self,
109109
model: Union[str, ModelPackage],
110110
training_type: Union[TrainingType, str] = TrainingType.LORA,
111-
model_package_group_name: Optional[Union[str, ModelPackageGroup]] = None,
111+
model_package_group: Optional[Union[str, ModelPackageGroup]] = None,
112112
custom_reward_function: Optional[Union[str, Evaluator]] = None,
113113
mlflow_resource_arn: Optional[Union[str, MlflowTrackingServer]] = None,
114114
mlflow_experiment_name: Optional[str] = None,
@@ -129,8 +129,8 @@ def __init__(
129129
self.model, self._model_name = _resolve_model_and_name(model, self.sagemaker_session)
130130

131131
self.training_type = training_type
132-
self.model_package_group_name = _validate_and_resolve_model_package_group(model,
133-
model_package_group_name)
132+
self.model_package_group = _validate_and_resolve_model_package_group(model,
133+
model_package_group)
134134
self.custom_reward_function = custom_reward_function
135135
self.mlflow_resource_arn = mlflow_resource_arn
136136
self.mlflow_experiment_name = mlflow_experiment_name
@@ -239,7 +239,7 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None,
239239
_validate_hyperparameter_values(final_hyperparameters)
240240

241241
model_package_config = _create_model_package_config(
242-
model_package_group_name=self.model_package_group_name,
242+
model_package_group_name=self.model_package_group,
243243
model=self.model,
244244
sagemaker_session=sagemaker_session
245245
)

0 commit comments

Comments
 (0)