Skip to content

fix(train): Skip default instance_type/instance_count when instance_groups is set#5564

Open
mufaddal-rohawala wants to merge 1 commit intoaws:masterfrom
mufaddal-rohawala:fix/heterogeneous-cluster-defaults
Open

fix(train): Skip default instance_type/instance_count when instance_groups is set#5564
mufaddal-rohawala wants to merge 1 commit intoaws:masterfrom
mufaddal-rohawala:fix/heterogeneous-cluster-defaults

Conversation

@mufaddal-rohawala
Copy link
Member

Issue

Fixes #5555

Description

When creating a ModelTrainer with a Compute config that uses instance_groups (heterogeneous cluster), TrainDefaults.get_compute() and JumpStartTrainDefaults.get_compute() unconditionally inject default instance_type (ml.m5.xlarge) and instance_count (1) when those fields are None.

With heterogeneous clusters, instance_type and instance_count are intentionally None because they are mutually exclusive with instance_groups in the SageMaker CreateTrainingJob API. This causes the API call to include both InstanceType/InstanceCount and InstanceGroups in the ResourceConfig, which SageMaker rejects with:

ValidationException: InstanceType or InstanceCount cannot be specified with InstanceGroups

Changes

Wrapped the default instance_type and instance_count injection in both methods with a if not compute.instance_groups: guard so defaults are only applied for homogeneous cluster configurations.

TrainDefaults.get_compute() — skips setting instance_type and instance_count defaults when instance_groups is present.

JumpStartTrainDefaults.get_compute() — same guard applied. volume_size_in_gb default is still set regardless since it applies to both cluster types.

Testing

This change is a minimal guard condition. When instance_groups is not set, behavior is identical to before. When instance_groups is set, instance_type and instance_count remain None as intended by the caller.

…roups is set

Guard the default injection of instance_type and instance_count in
TrainDefaults.get_compute() and JumpStartTrainDefaults.get_compute()
so that these values are not populated when instance_groups is
configured. The SageMaker API treats instance_type/instance_count
and instance_groups as mutually exclusive in ResourceConfig, and
unconditionally setting defaults causes a ValidationException.

Fixes aws#5555
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

ModelTrainer with heterogeneous cluster (instance_groups) fails — defaults.py injects instance_type/instance_count unconditionally

2 participants

Comments