From dd8a507cacf3be03b7649a83f042cd2315babc7b Mon Sep 17 00:00:00 2001
From: Varun Morishetty
Date: Thu, 19 Feb 2026 20:13:03 +0000
Subject: [PATCH] resolve alt config resolution for jumpstart models

---
 .../sagemaker/serve/model_builder_servers.py |  5 +-
 .../servers/test_model_builder_servers.py    | 59 +++++++++++++++++++
 .../test_model_builder_servers_coverage.py   | 36 +++++++++--
 3 files changed, 94 insertions(+), 6 deletions(-)

diff --git a/sagemaker-serve/src/sagemaker/serve/model_builder_servers.py b/sagemaker-serve/src/sagemaker/serve/model_builder_servers.py
index 831c37ee14..43af8b4f7a 100644
--- a/sagemaker-serve/src/sagemaker/serve/model_builder_servers.py
+++ b/sagemaker-serve/src/sagemaker/serve/model_builder_servers.py
@@ -851,11 +851,12 @@ def _build_for_jumpstart(self) -> Model:
         # Get JumpStart model configuration
         init_kwargs = get_init_kwargs(
             model_id=self.model,
-            model_version=self.model_version or "*", 
+            model_version=self.model_version or "*",
             region=self.region,
             instance_type=self.instance_type,
             tolerate_vulnerable_model=getattr(self, 'tolerate_vulnerable_model', None),
-            tolerate_deprecated_model=getattr(self, 'tolerate_deprecated_model', None)
+            tolerate_deprecated_model=getattr(self, 'tolerate_deprecated_model', None),
+            config_name=getattr(self, 'config_name', None),
         )
 
         # Configure image URI and environment variables
diff --git a/sagemaker-serve/tests/unit/servers/test_model_builder_servers.py b/sagemaker-serve/tests/unit/servers/test_model_builder_servers.py
index 4a57147ae9..b15e77a0b0 100644
--- a/sagemaker-serve/tests/unit/servers/test_model_builder_servers.py
+++ b/sagemaker-serve/tests/unit/servers/test_model_builder_servers.py
@@ -45,6 +45,7 @@ def __init__(self):
         self.framework = None
         self.framework_version = None
         self._is_mlflow_model = False
+        self.config_name = None
 
     def _deploy_local_endpoint(self, **kwargs):
         return Mock()
@@ -816,6 +817,64 @@ def test_build_unsupported_image_uri(self, mock_init):
             self.builder._build_for_jumpstart()
         self.assertIn("Unsupported", str(ctx.exception))
 
+    @patch('sagemaker.core.jumpstart.factory.utils.get_init_kwargs')
+    @patch('sagemaker.serve.model_builder_servers.prepare_djl_js_resources')
+    @patch.object(MockModelBuilderServers, '_prepare_for_mode')
+    @patch.object(MockModelBuilderServers, '_create_model')
+    def test_build_passes_config_name_to_get_init_kwargs(self, mock_create, mock_prepare_mode, mock_djl_res, mock_init):
+        """Test that config_name is forwarded to get_init_kwargs."""
+        mock_init_kwargs = Mock()
+        mock_init_kwargs.image_uri = "djl-inference:0.21.0"
+        mock_init_kwargs.env = {"TEST": "value"}
+        mock_init_kwargs.model_data = "s3://bucket/model.tar.gz"
+        mock_init.return_value = mock_init_kwargs
+        mock_djl_res.return_value = ({"config": "value"}, True)
+        mock_create.return_value = Mock()
+        self.builder.mode = Mode.LOCAL_CONTAINER
+        self.builder.image_uri = None
+        self.builder.config_name = "lmi-optimized"
+
+        self.builder._build_for_jumpstart()
+
+        mock_init.assert_called_once_with(
+            model_id=self.builder.model,
+            model_version="*",
+            region=self.builder.region,
+            instance_type=self.builder.instance_type,
+            tolerate_vulnerable_model=None,
+            tolerate_deprecated_model=None,
+            config_name="lmi-optimized",
+        )
+
+    @patch('sagemaker.core.jumpstart.factory.utils.get_init_kwargs')
+    @patch('sagemaker.serve.model_builder_servers.prepare_djl_js_resources')
+    @patch.object(MockModelBuilderServers, '_prepare_for_mode')
+    @patch.object(MockModelBuilderServers, '_create_model')
+    def test_build_passes_none_config_name_when_not_set(self, mock_create, mock_prepare_mode, mock_djl_res, mock_init):
+        """Test that config_name defaults to None when not set."""
+        mock_init_kwargs = Mock()
+        mock_init_kwargs.image_uri = "djl-inference:0.21.0"
+        mock_init_kwargs.env = {}
+        mock_init_kwargs.model_data = "s3://bucket/model.tar.gz"
+        mock_init.return_value = mock_init_kwargs
+        mock_djl_res.return_value = ({"config": "value"}, True)
+        mock_create.return_value = Mock()
+        self.builder.mode = Mode.LOCAL_CONTAINER
+        self.builder.image_uri = None
+        self.builder.config_name = None
+
+        self.builder._build_for_jumpstart()
+
+        mock_init.assert_called_once_with(
+            model_id=self.builder.model,
+            model_version="*",
+            region=self.builder.region,
+            instance_type=self.builder.instance_type,
+            tolerate_vulnerable_model=None,
+            tolerate_deprecated_model=None,
+            config_name=None,
+        )
+
     @patch('sagemaker.core.jumpstart.factory.utils.get_init_kwargs')
     @patch.object(MockModelBuilderServers, '_prepare_for_mode')
     @patch.object(MockModelBuilderServers, '_build_for_djl_jumpstart')
diff --git a/sagemaker-serve/tests/unit/test_model_builder_servers_coverage.py b/sagemaker-serve/tests/unit/test_model_builder_servers_coverage.py
index e3e21565a1..02b0962feb 100644
--- a/sagemaker-serve/tests/unit/test_model_builder_servers_coverage.py
+++ b/sagemaker-serve/tests/unit/test_model_builder_servers_coverage.py
@@ -347,10 +347,10 @@ def test_build_for_jumpstart_routes_to_djl(self, mock_prepare, mock_build_djl, m
         mock_init_kwargs.image_uri = "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.21.0-deepspeed0.8.0-cu117"
         mock_init_kwargs.env = {}
         mock_get_kwargs.return_value = mock_init_kwargs
-        
+
         mock_model = Mock(spec=Model)
         mock_build_djl.return_value = mock_model
-        
+
         builder = ModelBuilder(
             model="huggingface-llm-falcon-7b",
             role_arn=MOCK_ROLE_ARN,
@@ -358,13 +358,41 @@ def test_build_for_jumpstart_routes_to_djl(self, mock_prepare, mock_build_djl, m
             mode=Mode.SAGEMAKER_ENDPOINT
         )
         builder._optimizing = False
-        
+
         result = builder._build_for_jumpstart()
-        
+
         self.assertEqual(result, mock_model)
         self.assertEqual(builder.model_server, ModelServer.DJL_SERVING)
         mock_build_djl.assert_called_once()
 
+    @patch('sagemaker.core.jumpstart.factory.utils.get_init_kwargs')
+    @patch('sagemaker.serve.model_builder.ModelBuilder._build_for_djl_jumpstart')
+    @patch('sagemaker.serve.model_builder.ModelBuilder._prepare_for_mode')
+    def test_build_for_jumpstart_passes_config_name(self, mock_prepare, mock_build_djl, mock_get_kwargs):
+        """Test that config_name is forwarded to get_init_kwargs."""
+        mock_init_kwargs = Mock()
+        mock_init_kwargs.image_uri = "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.21.0-deepspeed0.8.0-cu117"
+        mock_init_kwargs.env = {}
+        mock_get_kwargs.return_value = mock_init_kwargs
+
+        mock_model = Mock(spec=Model)
+        mock_build_djl.return_value = mock_model
+
+        builder = ModelBuilder(
+            model="meta-textgeneration-llama-3-3-70b-instruct",
+            role_arn=MOCK_ROLE_ARN,
+            sagemaker_session=self.mock_session,
+            mode=Mode.SAGEMAKER_ENDPOINT
+        )
+        builder._optimizing = False
+        builder.config_name = "lmi-optimized"
+
+        builder._build_for_jumpstart()
+
+        mock_get_kwargs.assert_called_once()
+        call_kwargs = mock_get_kwargs.call_args
+        self.assertEqual(call_kwargs.kwargs.get("config_name") or call_kwargs[1].get("config_name"), "lmi-optimized")
+
     @patch('sagemaker.core.jumpstart.factory.utils.get_init_kwargs')
     @patch('sagemaker.serve.model_builder.ModelBuilder._build_for_tgi_jumpstart')
     @patch('sagemaker.serve.model_builder.ModelBuilder._prepare_for_mode')