Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions sagemaker-serve/src/sagemaker/serve/model_builder_servers.py
Original file line number Diff line number Diff line change
Expand Up @@ -851,11 +851,12 @@ def _build_for_jumpstart(self) -> Model:
# Get JumpStart model configuration
init_kwargs = get_init_kwargs(
model_id=self.model,
model_version=self.model_version or "*",
model_version=self.model_version or "*",
region=self.region,
instance_type=self.instance_type,
tolerate_vulnerable_model=getattr(self, 'tolerate_vulnerable_model', None),
tolerate_deprecated_model=getattr(self, 'tolerate_deprecated_model', None)
tolerate_deprecated_model=getattr(self, 'tolerate_deprecated_model', None),
config_name=getattr(self, 'config_name', None),
)

# Configure image URI and environment variables
Expand Down
59 changes: 59 additions & 0 deletions sagemaker-serve/tests/unit/servers/test_model_builder_servers.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def __init__(self):
self.framework = None
self.framework_version = None
self._is_mlflow_model = False
self.config_name = None

def _deploy_local_endpoint(self, **kwargs):
return Mock()
Expand Down Expand Up @@ -816,6 +817,64 @@ def test_build_unsupported_image_uri(self, mock_init):
self.builder._build_for_jumpstart()
self.assertIn("Unsupported", str(ctx.exception))

@patch('sagemaker.core.jumpstart.factory.utils.get_init_kwargs')
@patch('sagemaker.serve.model_builder_servers.prepare_djl_js_resources')
@patch.object(MockModelBuilderServers, '_prepare_for_mode')
@patch.object(MockModelBuilderServers, '_create_model')
def test_build_passes_config_name_to_get_init_kwargs(self, mock_create, mock_prepare_mode, mock_djl_res, mock_init):
"""Test that config_name is forwarded to get_init_kwargs."""
mock_init_kwargs = Mock()
mock_init_kwargs.image_uri = "djl-inference:0.21.0"
mock_init_kwargs.env = {"TEST": "value"}
mock_init_kwargs.model_data = "s3://bucket/model.tar.gz"
mock_init.return_value = mock_init_kwargs
mock_djl_res.return_value = ({"config": "value"}, True)
mock_create.return_value = Mock()
self.builder.mode = Mode.LOCAL_CONTAINER
self.builder.image_uri = None
self.builder.config_name = "lmi-optimized"

self.builder._build_for_jumpstart()

mock_init.assert_called_once_with(
model_id=self.builder.model,
model_version="*",
region=self.builder.region,
instance_type=self.builder.instance_type,
tolerate_vulnerable_model=None,
tolerate_deprecated_model=None,
config_name="lmi-optimized",
)

@patch('sagemaker.core.jumpstart.factory.utils.get_init_kwargs')
@patch('sagemaker.serve.model_builder_servers.prepare_djl_js_resources')
@patch.object(MockModelBuilderServers, '_prepare_for_mode')
@patch.object(MockModelBuilderServers, '_create_model')
def test_build_passes_none_config_name_when_not_set(self, mock_create, mock_prepare_mode, mock_djl_res, mock_init):
"""Test that config_name defaults to None when not set."""
mock_init_kwargs = Mock()
mock_init_kwargs.image_uri = "djl-inference:0.21.0"
mock_init_kwargs.env = {}
mock_init_kwargs.model_data = "s3://bucket/model.tar.gz"
mock_init.return_value = mock_init_kwargs
mock_djl_res.return_value = ({"config": "value"}, True)
mock_create.return_value = Mock()
self.builder.mode = Mode.LOCAL_CONTAINER
self.builder.image_uri = None
self.builder.config_name = None

self.builder._build_for_jumpstart()

mock_init.assert_called_once_with(
model_id=self.builder.model,
model_version="*",
region=self.builder.region,
instance_type=self.builder.instance_type,
tolerate_vulnerable_model=None,
tolerate_deprecated_model=None,
config_name=None,
)

@patch('sagemaker.core.jumpstart.factory.utils.get_init_kwargs')
@patch.object(MockModelBuilderServers, '_prepare_for_mode')
@patch.object(MockModelBuilderServers, '_build_for_djl_jumpstart')
Expand Down
36 changes: 32 additions & 4 deletions sagemaker-serve/tests/unit/test_model_builder_servers_coverage.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,24 +347,52 @@ def test_build_for_jumpstart_routes_to_djl(self, mock_prepare, mock_build_djl, m
mock_init_kwargs.image_uri = "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.21.0-deepspeed0.8.0-cu117"
mock_init_kwargs.env = {}
mock_get_kwargs.return_value = mock_init_kwargs

mock_model = Mock(spec=Model)
mock_build_djl.return_value = mock_model

builder = ModelBuilder(
model="huggingface-llm-falcon-7b",
role_arn=MOCK_ROLE_ARN,
sagemaker_session=self.mock_session,
mode=Mode.SAGEMAKER_ENDPOINT
)
builder._optimizing = False

result = builder._build_for_jumpstart()

self.assertEqual(result, mock_model)
self.assertEqual(builder.model_server, ModelServer.DJL_SERVING)
mock_build_djl.assert_called_once()

@patch('sagemaker.core.jumpstart.factory.utils.get_init_kwargs')
@patch('sagemaker.serve.model_builder.ModelBuilder._build_for_djl_jumpstart')
@patch('sagemaker.serve.model_builder.ModelBuilder._prepare_for_mode')
def test_build_for_jumpstart_passes_config_name(self, mock_prepare, mock_build_djl, mock_get_kwargs):
"""Test that config_name is forwarded to get_init_kwargs."""
mock_init_kwargs = Mock()
mock_init_kwargs.image_uri = "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.21.0-deepspeed0.8.0-cu117"
mock_init_kwargs.env = {}
mock_get_kwargs.return_value = mock_init_kwargs

mock_model = Mock(spec=Model)
mock_build_djl.return_value = mock_model

builder = ModelBuilder(
model="meta-textgeneration-llama-3-3-70b-instruct",
role_arn=MOCK_ROLE_ARN,
sagemaker_session=self.mock_session,
mode=Mode.SAGEMAKER_ENDPOINT
)
builder._optimizing = False
builder.config_name = "lmi-optimized"

builder._build_for_jumpstart()

mock_get_kwargs.assert_called_once()
call_kwargs = mock_get_kwargs.call_args
self.assertEqual(call_kwargs.kwargs.get("config_name") or call_kwargs[1].get("config_name"), "lmi-optimized")

@patch('sagemaker.core.jumpstart.factory.utils.get_init_kwargs')
@patch('sagemaker.serve.model_builder.ModelBuilder._build_for_tgi_jumpstart')
@patch('sagemaker.serve.model_builder.ModelBuilder._prepare_for_mode')
Expand Down
Loading