Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
a545b1a
add support for on-prem
TarunRavikumar Oct 28, 2025
ca0a703
clean up on-prem artifacts
TarunRavikumar Oct 28, 2025
d7e5fab
add back comments from initial code
TarunRavikumar Oct 28, 2025
9b45d81
fix lint
TarunRavikumar Oct 28, 2025
f23d823
use ecr image repo:tag directly
TarunRavikumar Nov 12, 2025
20a8dc2
fix: isort import ordering
TarunRavikumar Dec 11, 2025
cf4a411
fix: remove unused infra_config import
TarunRavikumar Dec 11, 2025
d19e6f2
fix: mypy type annotation errors
TarunRavikumar Dec 11, 2025
f72fea2
Merge branch 'main' into tr/onprem
TarunRavikumar Dec 12, 2025
3bea65a
fix: remove type annotation causing mypy no-redef error
TarunRavikumar Dec 12, 2025
0ad17fb
fix: mypy type errors in s3_utils.py and io.py - use botocore.config.…
TarunRavikumar Dec 12, 2025
48eaac4
fix: mypy typeddict-item errors - use broad type ignore
TarunRavikumar Dec 12, 2025
5257762
fix: update test mocks to use get_s3_resource from s3_utils
TarunRavikumar Dec 12, 2025
412fe41
test: add unit tests for s3_utils, onprem_docker_repository, and onpr…
TarunRavikumar Dec 12, 2025
5b3f796
style: format test files with black
TarunRavikumar Dec 12, 2025
8ff1eea
refactor: use filesystem_gateway abstraction for S3 operations
TarunRavikumar Dec 15, 2025
aaca0b8
fix: deduplicate S3 client config by using centralized s3_utils
TarunRavikumar Dec 15, 2025
2687232
fix: add pagination to list_objects to handle >1000 objects
TarunRavikumar Dec 15, 2025
82e31d6
fix: make OnPremDockerRepository.get_image_url consistent with ECR/ACR
TarunRavikumar Dec 15, 2025
70a5633
refactor: add explicit on-prem branches in dependencies.py for clarity
TarunRavikumar Dec 15, 2025
69f87a9
feat: implement Redis LLEN for queue depth in OnPremQueueEndpointReso…
TarunRavikumar Dec 15, 2025
4c85f12
fix: replace mutable default argument with None in _get_client
TarunRavikumar Dec 15, 2025
52cc826
refactor: extract inline import to module-level helper function
TarunRavikumar Dec 15, 2025
6a5ffcc
fix: reduce excessive debug logging in s3_utils
TarunRavikumar Dec 15, 2025
f27d817
chore: remove unused TYPE_CHECKING import
TarunRavikumar Dec 15, 2025
2e20e55
fix: make Dockerfile multi-arch compatible for ARM/AMD64
TarunRavikumar Dec 15, 2025
6c62b72
style: fix black formatting in test_onprem_queue_endpoint_resource_de…
TarunRavikumar Dec 15, 2025
7f434ee
fix: restore AWS_PROFILE env var fallback in s3_utils
TarunRavikumar Dec 15, 2025
982ecb1
fix: correct isort ordering in s3_filesystem_gateway.py
TarunRavikumar Dec 15, 2025
df98ddc
fix: use Literal type for s3 addressing_style to satisfy mypy
TarunRavikumar Dec 15, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 13 additions & 6 deletions model-engine/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,20 @@ RUN apt-get update && apt-get install -y \
telnet \
&& rm -rf /var/lib/apt/lists/*

RUN curl -Lo /bin/aws-iam-authenticator https://github.com/kubernetes-sigs/aws-iam-authenticator/releases/download/v0.5.9/aws-iam-authenticator_0.5.9_linux_amd64
RUN chmod +x /bin/aws-iam-authenticator
# Install aws-iam-authenticator (architecture-aware)
RUN ARCH=$(dpkg --print-architecture) && \
if [ "$ARCH" = "arm64" ]; then \
curl -Lo /bin/aws-iam-authenticator https://github.com/kubernetes-sigs/aws-iam-authenticator/releases/download/v0.5.9/aws-iam-authenticator_0.5.9_linux_arm64; \
else \
curl -Lo /bin/aws-iam-authenticator https://github.com/kubernetes-sigs/aws-iam-authenticator/releases/download/v0.5.9/aws-iam-authenticator_0.5.9_linux_amd64; \
fi && \
chmod +x /bin/aws-iam-authenticator

# Install kubectl
RUN curl -LO "https://dl.k8s.io/release/v1.23.13/bin/linux/amd64/kubectl" \
&& chmod +x kubectl \
&& mv kubectl /usr/local/bin/kubectl
# Install kubectl (architecture-aware)
RUN ARCH=$(dpkg --print-architecture) && \
curl -LO "https://dl.k8s.io/release/v1.23.13/bin/linux/${ARCH}/kubectl" && \
chmod +x kubectl && \
mv kubectl /usr/local/bin/kubectl

# Pin pip version
RUN pip install pip==24.2
Expand Down
65 changes: 38 additions & 27 deletions model-engine/model_engine_server/api/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,9 @@
from model_engine_server.infra.gateways.resources.live_endpoint_resource_gateway import (
LiveEndpointResourceGateway,
)
from model_engine_server.infra.gateways.resources.onprem_queue_endpoint_resource_delegate import (
OnPremQueueEndpointResourceDelegate,
)
from model_engine_server.infra.gateways.resources.queue_endpoint_resource_delegate import (
QueueEndpointResourceDelegate,
)
Expand All @@ -114,6 +117,7 @@
FakeDockerRepository,
LiveTokenizerRepository,
LLMFineTuneRepository,
OnPremDockerRepository,
RedisModelEndpointCacheRepository,
S3FileLLMFineTuneEventsRepository,
S3FileLLMFineTuneRepository,
Expand Down Expand Up @@ -225,6 +229,8 @@ def _get_external_interfaces(
queue_delegate = FakeQueueEndpointResourceDelegate()
elif infra_config().cloud_provider == "azure":
queue_delegate = ASBQueueEndpointResourceDelegate()
elif infra_config().cloud_provider == "onprem":
queue_delegate = OnPremQueueEndpointResourceDelegate()
else:
queue_delegate = SQSQueueEndpointResourceDelegate(
sqs_profile=os.getenv("SQS_PROFILE", hmi_config.sqs_profile)
Expand All @@ -238,6 +244,9 @@ def _get_external_interfaces(
elif infra_config().cloud_provider == "azure":
inference_task_queue_gateway = servicebus_task_queue_gateway
infra_task_queue_gateway = servicebus_task_queue_gateway
elif infra_config().cloud_provider == "onprem":
inference_task_queue_gateway = redis_task_queue_gateway
infra_task_queue_gateway = redis_task_queue_gateway
elif infra_config().celery_broker_type_redis:
inference_task_queue_gateway = redis_task_queue_gateway
infra_task_queue_gateway = redis_task_queue_gateway
Expand Down Expand Up @@ -274,16 +283,17 @@ def _get_external_interfaces(
monitoring_metrics_gateway=monitoring_metrics_gateway,
use_asyncio=(not CIRCLECI),
)
filesystem_gateway = (
ABSFilesystemGateway()
if infra_config().cloud_provider == "azure"
else S3FilesystemGateway()
)
llm_artifact_gateway = (
ABSLLMArtifactGateway()
if infra_config().cloud_provider == "azure"
else S3LLMArtifactGateway()
)
filesystem_gateway: FilesystemGateway
llm_artifact_gateway: LLMArtifactGateway
if infra_config().cloud_provider == "azure":
filesystem_gateway = ABSFilesystemGateway()
llm_artifact_gateway = ABSLLMArtifactGateway()
elif infra_config().cloud_provider == "onprem":
filesystem_gateway = S3FilesystemGateway() # Uses MinIO via s3_utils
llm_artifact_gateway = S3LLMArtifactGateway() # Uses MinIO via s3_utils
else:
filesystem_gateway = S3FilesystemGateway()
llm_artifact_gateway = S3LLMArtifactGateway()
model_endpoints_schema_gateway = LiveModelEndpointsSchemaGateway(
filesystem_gateway=filesystem_gateway
)
Expand Down Expand Up @@ -323,23 +333,20 @@ def _get_external_interfaces(
cron_job_gateway = LiveCronJobGateway()

llm_fine_tune_repository: LLMFineTuneRepository
llm_fine_tune_events_repository: LLMFineTuneEventsRepository
file_path = os.getenv(
"CLOUD_FILE_LLM_FINE_TUNE_REPOSITORY",
hmi_config.cloud_file_llm_fine_tune_repository,
)
if infra_config().cloud_provider == "azure":
llm_fine_tune_repository = ABSFileLLMFineTuneRepository(
file_path=file_path,
)
llm_fine_tune_repository = ABSFileLLMFineTuneRepository(file_path=file_path)
llm_fine_tune_events_repository = ABSFileLLMFineTuneEventsRepository()
elif infra_config().cloud_provider == "onprem":
llm_fine_tune_repository = S3FileLLMFineTuneRepository(file_path=file_path) # Uses MinIO
llm_fine_tune_events_repository = S3FileLLMFineTuneEventsRepository() # Uses MinIO
else:
llm_fine_tune_repository = S3FileLLMFineTuneRepository(
file_path=file_path,
)
llm_fine_tune_events_repository = (
ABSFileLLMFineTuneEventsRepository()
if infra_config().cloud_provider == "azure"
else S3FileLLMFineTuneEventsRepository()
)
llm_fine_tune_repository = S3FileLLMFineTuneRepository(file_path=file_path)
llm_fine_tune_events_repository = S3FileLLMFineTuneEventsRepository()
llm_fine_tuning_service = DockerImageBatchJobLLMFineTuningService(
docker_image_batch_job_gateway=docker_image_batch_job_gateway,
docker_image_batch_job_bundle_repo=docker_image_batch_job_bundle_repository,
Expand All @@ -350,17 +357,21 @@ def _get_external_interfaces(
docker_image_batch_job_gateway=docker_image_batch_job_gateway
)

file_storage_gateway = (
ABSFileStorageGateway()
if infra_config().cloud_provider == "azure"
else S3FileStorageGateway()
)
file_storage_gateway: FileStorageGateway
if infra_config().cloud_provider == "azure":
file_storage_gateway = ABSFileStorageGateway()
elif infra_config().cloud_provider == "onprem":
file_storage_gateway = S3FileStorageGateway() # Uses MinIO via s3_utils
else:
file_storage_gateway = S3FileStorageGateway()

docker_repository: DockerRepository
if CIRCLECI:
docker_repository = FakeDockerRepository()
elif infra_config().docker_repo_prefix.endswith("azurecr.io"):
elif infra_config().cloud_provider == "azure":
docker_repository = ACRDockerRepository()
elif infra_config().cloud_provider == "onprem":
docker_repository = OnPremDockerRepository()
else:
docker_repository = ECRDockerRepository()

Expand Down
20 changes: 14 additions & 6 deletions model-engine/model_engine_server/common/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,21 +90,29 @@ def from_yaml(cls, yaml_path):

@property
def cache_redis_url(self) -> str:
cloud_provider = infra_config().cloud_provider

if cloud_provider == "onprem":
if self.cache_redis_aws_url:
logger.info("On-prem deployment using cache_redis_aws_url")
return self.cache_redis_aws_url
redis_host = os.getenv("REDIS_HOST", "redis")
redis_port = getattr(infra_config(), "redis_port", 6379)
return f"redis://{redis_host}:{redis_port}/0"

if self.cache_redis_aws_url:
assert infra_config().cloud_provider == "aws", "cache_redis_aws_url is only for AWS"
assert cloud_provider == "aws", "cache_redis_aws_url is only for AWS"
if self.cache_redis_aws_secret_name:
logger.warning(
"Both cache_redis_aws_url and cache_redis_aws_secret_name are set. Using cache_redis_aws_url"
)
return self.cache_redis_aws_url
elif self.cache_redis_aws_secret_name:
assert (
infra_config().cloud_provider == "aws"
), "cache_redis_aws_secret_name is only for AWS"
creds = get_key_file(self.cache_redis_aws_secret_name) # Use default role
assert cloud_provider == "aws", "cache_redis_aws_secret_name is only for AWS"
creds = get_key_file(self.cache_redis_aws_secret_name)
return creds["cache-url"]

assert self.cache_redis_azure_host and infra_config().cloud_provider == "azure"
assert self.cache_redis_azure_host and cloud_provider == "azure"
username = os.getenv("AZURE_OBJECT_ID")
token = DefaultAzureCredential().get_token("https://redis.azure.com/.default")
password = token.token
Expand Down
10 changes: 4 additions & 6 deletions model-engine/model_engine_server/common/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,17 @@
import os
from typing import Any

import boto3
import smart_open
from model_engine_server.core.config import infra_config


def open_wrapper(uri: str, mode: str = "rt", **kwargs):
client: Any
cloud_provider: str
# This follows the 5.1.0 smart_open API
try:
cloud_provider = infra_config().cloud_provider
except Exception:
cloud_provider = "aws"

if cloud_provider == "azure":
from azure.identity import DefaultAzureCredential
from azure.storage.blob import BlobServiceClient
Expand All @@ -25,9 +23,9 @@ def open_wrapper(uri: str, mode: str = "rt", **kwargs):
DefaultAzureCredential(),
)
else:
profile_name = kwargs.get("aws_profile", os.getenv("AWS_PROFILE"))
session = boto3.Session(profile_name=profile_name)
client = session.client("s3")
from model_engine_server.infra.gateways.s3_utils import get_s3_client

client = get_s3_client(kwargs)

transport_params = {"client": client}
return smart_open.open(uri, mode, transport_params=transport_params)
31 changes: 21 additions & 10 deletions model-engine/model_engine_server/core/celery/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -531,17 +531,28 @@ def _get_backend_url_and_conf(
backend_url = get_redis_endpoint(1)
elif backend_protocol == "s3":
backend_url = "s3://"
if aws_role is None:
aws_session = session(infra_config().profile_ml_worker)
if infra_config().cloud_provider == "aws":
if aws_role is None:
aws_session = session(infra_config().profile_ml_worker)
else:
aws_session = session(aws_role)
out_conf_changes.update(
{
"s3_boto3_session": aws_session,
"s3_bucket": s3_bucket,
"s3_base_path": s3_base_path,
}
)
else:
aws_session = session(aws_role)
out_conf_changes.update(
{
"s3_boto3_session": aws_session,
"s3_bucket": s3_bucket,
"s3_base_path": s3_base_path,
}
)
logger.info(
"Non-AWS deployment, using environment variables for S3 backend credentials"
)
out_conf_changes.update(
{
"s3_bucket": s3_bucket,
"s3_base_path": s3_base_path,
}
)
elif backend_protocol == "abs":
backend_url = f"azureblockblob://{os.getenv('ABS_ACCOUNT_NAME')}"
else:
Expand Down
72 changes: 72 additions & 0 deletions model-engine/model_engine_server/core/configs/onprem.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
# On-premise deployment configuration
# This configuration file provides defaults for on-prem deployments.
# Many values can be overridden via environment variables (noted per key below).

cloud_provider: "onprem"
env: "production" # Can be: production, staging, development, local
k8s_cluster_name: "onprem-cluster"
dns_host_domain: "ml.company.local"
default_region: "us-east-1" # Placeholder for compatibility with cloud-agnostic code

# ====================
# Object Storage (MinIO/S3-compatible)
# ====================
s3_bucket: "model-engine"
# S3 endpoint URL - can be overridden by S3_ENDPOINT_URL env var
# Examples: "https://minio.company.local", "http://minio-service:9000"
s3_endpoint_url: "" # Set via S3_ENDPOINT_URL env var if not specified here
# MinIO requires path-style addressing (bucket in URL path, not subdomain)
s3_addressing_style: "path"

# ====================
# Redis Configuration
# ====================
# Redis is used for:
# - Celery task queue broker
# - Model endpoint caching
# - Inference autoscaling metrics
redis_host: "" # Set via REDIS_HOST env var (e.g., "redis.company.local" or "redis-service")
redis_port: 6379
# Whether to use Redis as Celery broker (true for on-prem)
celery_broker_type_redis: true

# ====================
# Celery Configuration
# ====================
# Backend protocol: "redis" for on-prem (not "s3" or "abs")
celery_backend_protocol: "redis"

# ====================
# Database Configuration
# ====================
# Database connection settings; credentials come from environment variables:
# DB_HOST, DB_PORT, DB_NAME, DB_USER, DB_PASSWORD
db_host: "postgres" # Default hostname, can be overridden by DB_HOST env var
db_port: 5432
db_name: "llm_engine"
db_engine_pool_size: 20
db_engine_max_overflow: 10
db_engine_echo: false
db_engine_echo_pool: false
db_engine_disconnect_strategy: "pessimistic"

# ====================
# Docker Registry Configuration
# ====================
# Docker registry prefix for container images
# Examples: "registry.company.local", "harbor.company.local/ml-platform"
# Leave empty if using full image paths directly
docker_repo_prefix: "registry.company.local"

# ====================
# Monitoring & Observability
# ====================
# Prometheus server address for metrics (optional)
# prometheus_server_address: "http://prometheus:9090"

# ====================
# Not applicable for on-prem (kept for compatibility)
# ====================
ml_account_id: "onprem"
profile_ml_worker: "default"
profile_ml_inference_worker: "default"
14 changes: 12 additions & 2 deletions model-engine/model_engine_server/db/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,13 +59,23 @@ def get_engine_url(
key_file = get_key_file_name(env) # type: ignore
logger.debug(f"Using key file {key_file}")

if infra_config().cloud_provider == "azure":
if infra_config().cloud_provider == "onprem":
user = os.environ.get("DB_USER", "postgres")
password = os.environ.get("DB_PASSWORD", "postgres")
host = os.environ.get("DB_HOST_RO") or os.environ.get("DB_HOST", "localhost")
port = os.environ.get("DB_PORT", "5432")
dbname = os.environ.get("DB_NAME", "llm_engine")
logger.info(f"Connecting to db {host}:{port}, name {dbname}")

engine_url = f"postgresql://{user}:{password}@{host}:{port}/{dbname}"

elif infra_config().cloud_provider == "azure":
client = SecretClient(
vault_url=f"https://{os.environ.get('KEYVAULT_NAME')}.vault.azure.net",
credential=DefaultAzureCredential(),
)
db = client.get_secret(key_file).value
user = os.environ.get("AZURE_IDENTITY_NAME")
user = os.environ.get("AZURE_IDENTITY_NAME", "")
token = DefaultAzureCredential().get_token(
"https://ossrdbms-aad.database.windows.net/.default"
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,12 @@
from typing_extensions import Literal


def _is_onprem_deployment() -> bool:
    """Return True when the configured cloud provider is "onprem".

    The import is deferred to call time to avoid a module-level import
    cycle with the config package.
    """
    from model_engine_server.core.config import infra_config

    provider = infra_config().cloud_provider
    return provider == "onprem"


class ModelBundlePackagingType(str, Enum):
"""
The canonical list of possible packaging types for Model Bundles.
Expand Down Expand Up @@ -71,10 +77,15 @@ def validate_fields_present_for_framework_type(cls, field_values):
"type was selected."
)
else: # field_values["framework_type"] == ModelBundleFramework.CUSTOM:
assert field_values["ecr_repo"] and field_values["image_tag"], (
"Expected `ecr_repo` and `image_tag` to be non-null because the custom framework "
assert field_values["image_tag"], (
"Expected `image_tag` to be non-null because the custom framework "
"type was selected."
)
if not field_values.get("ecr_repo") and not _is_onprem_deployment():
raise ValueError(
"Expected `ecr_repo` to be non-null for custom framework. "
"For on-prem deployments, ecr_repo can be omitted to use direct image references."
)
return field_values

model_config = ConfigDict(from_attributes=True)
Expand Down
Loading