Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
239 changes: 226 additions & 13 deletions sagemaker-core/src/sagemaker/core/remote_function/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,12 +175,12 @@
fi

printf "INFO: Invoking remote function inside conda environment: $conda_env.\\n"
printf "INFO: $conda_exe run -n $conda_env python -m sagemaker.train.remote_function.invoke_function \\n"
$conda_exe run -n $conda_env python -m sagemaker.train.remote_function.invoke_function "$@"
printf "INFO: $conda_exe run -n $conda_env python -m sagemaker.core.remote_function.invoke_function \\n"
$conda_exe run -n $conda_env python -m sagemaker.core.remote_function.invoke_function "$@"
else
printf "INFO: No conda env provided. Invoking remote function\\n"
printf "INFO: python -m sagemaker.train.remote_function.invoke_function \\n"
python -m sagemaker.train.remote_function.invoke_function "$@"
printf "INFO: python -m sagemaker.core.remote_function.invoke_function \\n"
python -m sagemaker.core.remote_function.invoke_function "$@"
fi
"""

Expand Down Expand Up @@ -234,14 +234,14 @@
-mca btl_vader_single_copy_mechanism none -mca plm_rsh_num_concurrent $SM_HOST_COUNT \
-x NCCL_SOCKET_IFNAME=$SM_NETWORK_INTERFACE_NAME -x LD_LIBRARY_PATH -x PATH \

python -m mpi4py -m sagemaker.train.remote_function.invoke_function \\n"
python -m mpi4py -m sagemaker.core.remote_function.invoke_function \\n"
$conda_exe run -n $conda_env mpirun --host $SM_HOSTS_LIST -np $SM_NPROC_PER_NODE \
--allow-run-as-root --display-map --tag-output -mca btl_tcp_if_include $SM_NETWORK_INTERFACE_NAME \
-mca plm_rsh_no_tree_spawn 1 -mca pml ob1 -mca btl ^openib -mca orte_abort_on_non_zero_status 1 \
-mca btl_vader_single_copy_mechanism none -mca plm_rsh_num_concurrent $SM_HOST_COUNT \
-x NCCL_SOCKET_IFNAME=$SM_NETWORK_INTERFACE_NAME -x LD_LIBRARY_PATH -x PATH \
$SM_FI_PROVIDER $SM_NCCL_PROTO $SM_FI_EFA_USE_DEVICE_RDMA \
python -m mpi4py -m sagemaker.train.remote_function.invoke_function "$@"
python -m mpi4py -m sagemaker.core.remote_function.invoke_function "$@"

python /opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{MPI_UTILS_SCRIPT_NAME} --job_ended 1
else
Expand All @@ -259,15 +259,15 @@
-mca btl_vader_single_copy_mechanism none -mca plm_rsh_num_concurrent $SM_HOST_COUNT \
-x NCCL_SOCKET_IFNAME=$SM_NETWORK_INTERFACE_NAME -x LD_LIBRARY_PATH -x PATH \
$SM_FI_PROVIDER $SM_NCCL_PROTO $SM_FI_EFA_USE_DEVICE_RDMA \
python -m mpi4py -m sagemaker.train.remote_function.invoke_function \\n"
python -m mpi4py -m sagemaker.core.remote_function.invoke_function \\n"

mpirun --host $SM_HOSTS_LIST -np $SM_NPROC_PER_NODE \
--allow-run-as-root --display-map --tag-output -mca btl_tcp_if_include $SM_NETWORK_INTERFACE_NAME \
-mca plm_rsh_no_tree_spawn 1 -mca pml ob1 -mca btl ^openib -mca orte_abort_on_non_zero_status 1 \
-mca btl_vader_single_copy_mechanism none -mca plm_rsh_num_concurrent $SM_HOST_COUNT \
-x NCCL_SOCKET_IFNAME=$SM_NETWORK_INTERFACE_NAME -x LD_LIBRARY_PATH -x PATH \
$SM_FI_PROVIDER $SM_NCCL_PROTO $SM_FI_EFA_USE_DEVICE_RDMA \
python -m mpi4py -m sagemaker.train.remote_function.invoke_function "$@"
python -m mpi4py -m sagemaker.core.remote_function.invoke_function "$@"

python /opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{MPI_UTILS_SCRIPT_NAME} --job_ended 1
else
Expand Down Expand Up @@ -320,18 +320,18 @@
printf "INFO: Invoking remote function with torchrun inside conda environment: $conda_env.\\n"
printf "INFO: $conda_exe run -n $conda_env torchrun --nnodes $SM_HOST_COUNT --nproc_per_node $SM_NPROC_PER_NODE \
--master_addr $SM_MASTER_ADDR --master_port $SM_MASTER_PORT --node_rank $SM_CURRENT_HOST_RANK \
-m sagemaker.train.remote_function.invoke_function \\n"
-m sagemaker.core.remote_function.invoke_function \\n"

$conda_exe run -n $conda_env torchrun --nnodes $SM_HOST_COUNT --nproc_per_node $SM_NPROC_PER_NODE \
--master_addr $SM_MASTER_ADDR --master_port $SM_MASTER_PORT --node_rank $SM_CURRENT_HOST_RANK \
-m sagemaker.train.remote_function.invoke_function "$@"
-m sagemaker.core.remote_function.invoke_function "$@"
else
printf "INFO: No conda env provided. Invoking remote function with torchrun\\n"
printf "INFO: torchrun --nnodes $SM_HOST_COUNT --nproc_per_node $SM_NPROC_PER_NODE --master_addr $SM_MASTER_ADDR \
--master_port $SM_MASTER_PORT --node_rank $SM_CURRENT_HOST_RANK -m sagemaker.train.remote_function.invoke_function \\n"
--master_port $SM_MASTER_PORT --node_rank $SM_CURRENT_HOST_RANK -m sagemaker.core.remote_function.invoke_function \\n"

torchrun --nnodes $SM_HOST_COUNT --nproc_per_node $SM_NPROC_PER_NODE --master_addr $SM_MASTER_ADDR \
--master_port $SM_MASTER_PORT --node_rank $SM_CURRENT_HOST_RANK -m sagemaker.train.remote_function.invoke_function "$@"
--master_port $SM_MASTER_PORT --node_rank $SM_CURRENT_HOST_RANK -m sagemaker.core.remote_function.invoke_function "$@"
fi
"""

Expand Down Expand Up @@ -1259,7 +1259,215 @@ def _prepare_and_upload_runtime_scripts(
return upload_path


def _generate_input_data_config(job_settings: _JobSettings, s3_base_uri: str):
def _decrement_version(version_str: str) -> str:
"""Decrement a version string by one minor or patch version.

Rules:
- If patch version is 0 (e.g., 3.2.0), decrement minor: 3.2.0 -> 3.1.0
- If patch version is not 0 (e.g., 3.1.2), decrement patch: 3.1.2 -> 3.1.1

Args:
version_str: Version string (e.g., "3.2.0")

Returns:
Decremented version string
"""
from packaging import version as pkg_version

try:
parsed = pkg_version.parse(version_str)
major = parsed.major
minor = parsed.minor
patch = parsed.micro

if patch == 0:
# Decrement minor version
minor = max(0, minor - 1)
else:
# Decrement patch version
patch = max(0, patch - 1)

return f"{major}.{minor}.{patch}"
except Exception:
return version_str


def _resolve_version_from_specifier(specifier_str: str) -> str:
"""Resolve the version to check based on upper bounds.

Upper bounds take priority. If upper bound is <4.0.0, it's safe (V3 only).
If no upper bound exists, it's safe (unbounded).
If the decremented upper bound is less than a lower bound, use the lower bound.

Args:
specifier_str: Version specifier string (e.g., ">=3.2.0", "<3.2.0", "==3.1.0")

Returns:
The resolved version string to check, or None if safe
"""
import re
from packaging import version as pkg_version

# Handle exact version pinning (==)
match = re.search(r'==\s*([\d.]+)', specifier_str)
if match:
return match.group(1)

# Extract lower bounds for comparison
lower_bounds = []
for match in re.finditer(r'>=\s*([\d.]+)', specifier_str):
lower_bounds.append(match.group(1))

# Handle upper bounds - find the most restrictive one
upper_bounds = []

# Find all <= bounds
for match in re.finditer(r'<=\s*([\d.]+)', specifier_str):
upper_bounds.append(('<=', match.group(1)))

# Find all < bounds
for match in re.finditer(r'<\s*([\d.]+)', specifier_str):
upper_bounds.append(('<', match.group(1)))

if upper_bounds:
# Sort by version to find the most restrictive (lowest) upper bound
upper_bounds.sort(key=lambda x: pkg_version.parse(x[1]))
operator, version = upper_bounds[0]

# Special case: if upper bound is <4.0.0, it's safe (V3 only)
try:
parsed_upper = pkg_version.parse(version)
if operator == '<' and parsed_upper.major == 4 and parsed_upper.minor == 0 and parsed_upper.micro == 0:
# <4.0.0 means V3 only, which is safe
return None
except Exception:
pass

resolved_version = version
if operator == '<':
resolved_version = _decrement_version(version)

# If we have a lower bound and the resolved version is less than it, use the lower bound
if lower_bounds:
try:
resolved_parsed = pkg_version.parse(resolved_version)
for lower_bound_str in lower_bounds:
lower_parsed = pkg_version.parse(lower_bound_str)
if resolved_parsed < lower_parsed:
resolved_version = lower_bound_str
except Exception:
pass

return resolved_version

# For lower bounds only (>=, >), we don't check
return None


def _check_sagemaker_version_compatibility(sagemaker_requirement: str) -> None:
"""Check if the sagemaker version requirement uses incompatible hashing.

Raises ValueError if the requirement would install a version that uses HMAC hashing
(which is incompatible with the current SHA256-based integrity checks).

Args:
sagemaker_requirement: The sagemaker requirement string (e.g., "sagemaker>=3.2.0")

Raises:
ValueError: If the requirement would install a version using HMAC hashing
"""
import re
from packaging import version as pkg_version

match = re.search(r'sagemaker\s*(.+)$', sagemaker_requirement.strip(), re.IGNORECASE)
if not match:
return

specifier_str = match.group(1).strip()

# Resolve the version that would be installed
resolved_version_str = _resolve_version_from_specifier(specifier_str)
if not resolved_version_str:
# No upper bound or exact version, so we can't determine if it's bad
return

try:
resolved_version = pkg_version.parse(resolved_version_str)
except Exception:
return

# Define HMAC thresholds for each major version
v2_hmac_threshold = pkg_version.parse("2.256.0")
v3_hmac_threshold = pkg_version.parse("3.2.0")

# Check if the resolved version uses HMAC hashing
uses_hmac = False
if resolved_version.major == 2 and resolved_version < v2_hmac_threshold:
uses_hmac = True
elif resolved_version.major == 3 and resolved_version < v3_hmac_threshold:
uses_hmac = True

if uses_hmac:
raise ValueError(
f"The sagemaker version specified in requirements.txt ({sagemaker_requirement}) "
f"could install a version using HMAC-based integrity checks which are incompatible "
f"with the current SHA256-based integrity checks. Please update to "
f"sagemaker>=2.256.0,<3.0.0 (for V2) or sagemaker>=3.2.0,<4.0.0 (for V3)."
)


def _ensure_sagemaker_dependency(local_dependencies_path: str) -> str:
"""Ensure sagemaker>=3.2.0 is in the dependencies.

This function ensures that the remote environment has a compatible version of sagemaker
that includes the fix for the HMAC key security issue. Versions < 3.2.0 use HMAC-based
integrity checks which require the REMOTE_FUNCTION_SECRET_KEY environment variable.
Versions >= 3.2.0 use SHA256-based integrity checks which are secure and don't require
the secret key.

If no dependencies are provided, creates a temporary requirements.txt with sagemaker.
If dependencies are provided, appends sagemaker if not already present.

Args:
local_dependencies_path: Path to user's dependencies file or None

Returns:
Path to the dependencies file (created or modified)

Raises:
ValueError: If user has pinned sagemaker to a version using HMAC hashing
"""
import tempfile

SAGEMAKER_MIN_VERSION = "sagemaker>=3.2.0,<4.0.0"

if local_dependencies_path is None:
fd, req_file = tempfile.mkstemp(suffix=".txt", prefix="sagemaker_requirements_")
os.close(fd)

with open(req_file, "w") as f:
f.write(f"{SAGEMAKER_MIN_VERSION}\n")
logger.info("Created temporary requirements.txt at %s with %s", req_file, SAGEMAKER_MIN_VERSION)
return req_file

if local_dependencies_path.endswith(".txt"):
with open(local_dependencies_path, "r") as f:
content = f.read()

if "sagemaker" in content.lower():
for line in content.split('\n'):
if 'sagemaker' in line.lower():
_check_sagemaker_version_compatibility(line.strip())
break
else:
with open(local_dependencies_path, "a") as f:
f.write(f"\n{SAGEMAKER_MIN_VERSION}\n")
logger.info("Appended %s to requirements.txt", SAGEMAKER_MIN_VERSION)

return local_dependencies_path


def _generate_input_data_config(job_settings, s3_base_uri):
"""Generates input data config"""
from sagemaker.core.workflow.utilities import load_step_compilation_context

Expand Down Expand Up @@ -1288,6 +1496,11 @@ def _generate_input_data_config(job_settings: _JobSettings, s3_base_uri: str):

local_dependencies_path = RuntimeEnvironmentManager().snapshot(job_settings.dependencies)

# Ensure sagemaker dependency is included to prevent version mismatch issues
# Resolves issue where computing hash for integrity check changed in 3.2.0
local_dependencies_path = _ensure_sagemaker_dependency(local_dependencies_path)
job_settings.dependencies = local_dependencies_path

if step_compilation_context:
with _tmpdir() as tmp_dir:
script_and_dependencies_s3uri = _prepare_dependencies_and_pre_execution_scripts(
Expand Down
Loading
Loading