diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index ab2b55634..02d91201b 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -11,12 +11,13 @@ repos: # MyPy for static type checking - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.16.0 + rev: v1.19.1 hooks: - id: mypy additional_dependencies: [ types-requests, types-python-dateutil, + types-Pygments, pydantic, fastapi, pytest, diff --git a/strix/config/models.py b/strix/config/models.py index 7f6227f38..81c999373 100644 --- a/strix/config/models.py +++ b/strix/config/models.py @@ -59,6 +59,42 @@ def _resolve_prefixed_model( ), ) +RECOMMENDED_MODEL_NAMES = ( + "openai/gpt-5.5", + "openai/gpt-5.5-pro", + "openai/gpt-5.4", + "openai/gpt-5.4-pro", + "openai/gpt-5.3-codex", + "anthropic/claude-opus-4-8", + "anthropic/claude-sonnet-4-6", + "vertex_ai/gemini-3.1-pro-preview", + "gemini/gemini-3.1-pro-preview", + "xai/grok-4.3", + "deepseek/deepseek-v4-pro", + "deepseek/deepseek-reasoner", + "dashscope/qwen3-max-2026-01-23", + "moonshot/kimi-k2.7-code", + "moonshot/kimi-k2.6", + "mistral/mistral-medium-3-5", + "mistral/magistral-medium-latest", +) + +_RECOMMENDED_MODEL_NAME_SET = frozenset(name.lower() for name in RECOMMENDED_MODEL_NAMES) + +FRONTIER_MODEL_FAMILIES = ( + (("azure", "azure_ai", "bedrock_mantle", "openai"), ("gpt-5",)), + ( + ("anthropic", "azure_ai", "bedrock", "claude", "databricks", "snowflake", "vertex_ai"), + ("claude-opus-4", "claude-sonnet-4"), + ), + (("google", "gemini", "vertex_ai"), ("gemini-3",)), + (("xai", "x-ai"), ("grok-4",)), + (("deepseek",), ("deepseek-v4", "deepseek-r1", "deepseek-reasoner")), + (("alibaba", "dashscope", "qwen"), ("qwen3.7", "qwen3.5", "qwen3-max")), + (("moonshot", "moonshotai", "kimi"), ("kimi-k2.7", "kimi-k2.6", "kimi-k2.5")), + (("mistral", "mistralai"), ("mistral-medium-3-5", "magistral-medium")), +) + def configure_sdk_model_defaults(settings: Settings) -> None: """Apply Strix config to SDK-native defaults.""" @@ -154,6 +190,78 @@ def model_supports_reasoning(model_name: str) -> bool: return bool(entry and entry.get("supports_reasoning")) +def is_recommended_or_frontier_model(model_name: str) -> bool: + """Return whether a model is recommended or in a frontier model family.""" + name = _normalized_model_name(model_name) + if not name: + return False + if name in _RECOMMENDED_MODEL_NAME_SET: + return True + provider_name, bare_model_name = _split_model_provider(name) + return any( + _matches_frontier_family(provider_name, bare_model_name, provider_markers, prefixes) + for provider_markers, prefixes in FRONTIER_MODEL_FAMILIES + ) + + +def _normalized_model_name(model_name: str) -> str: + name = model_name.strip().lower() + for prefix in ("litellm/", "any-llm/"): + if name.startswith(prefix): + name = name[len(prefix) :] + break + return name + + +def _split_model_provider(model_name: str) -> tuple[str | None, str]: + if "/" not in model_name: + return None, model_name + provider_name, bare_model_name = model_name.rsplit("/", 1) + return provider_name, bare_model_name + + +def _matches_frontier_family( + provider_name: str | None, + model_name: str, + provider_markers: tuple[str, ...], + model_prefixes: tuple[str, ...], +) -> bool: + if not _matches_model_prefix(model_name, model_prefixes): + return False + if provider_name is None: + return True + return _contains_provider_marker( + provider_name, provider_markers, split_compound_names=True + ) or _contains_provider_marker(model_name, provider_markers) + + +def _matches_model_prefix(model_name: str, model_prefixes: tuple[str, ...]) -> bool: + return any( + candidate.startswith(prefix) + for candidate in _model_name_candidates(model_name) + for prefix in model_prefixes + ) + + +def _model_name_candidates(model_name: str) -> tuple[str, ...]: + if "." not in model_name: + return (model_name,) + suffixes = tuple( + model_name.split(".", index)[-1] for index in range(1, model_name.count(".") + 1) + ) + return (model_name, *suffixes) + + +def _contains_provider_marker( + value: str, provider_markers: tuple[str, ...], *, split_compound_names: bool = False +) -> bool: + parts = set(value.replace(".", "/").split("/")) + if split_compound_names: + for separator in ("_", "-"): + parts.update(piece for part in tuple(parts) for piece in part.split(separator)) + return any(marker in parts for marker in provider_markers) + + def is_known_openai_bare_model(model_name: str) -> bool: import litellm diff --git a/strix/core/hooks.py b/strix/core/hooks.py index 6b0d59241..64562d743 100644 --- a/strix/core/hooks.py +++ b/strix/core/hooks.py @@ -28,7 +28,10 @@ class ReportUsageHooks(RunHooks[dict[str, Any]]): def __init__(self, *, model: str, max_budget_usd: float | None = None) -> None: import math - if max_budget_usd is not None and (not math.isfinite(max_budget_usd) or max_budget_usd <= 0): + + if max_budget_usd is not None and ( + not math.isfinite(max_budget_usd) or max_budget_usd <= 0 + ): raise ValueError("max_budget_usd must be a finite number greater than 0") self._model = model self._max_budget_usd = max_budget_usd diff --git a/strix/interface/main.py b/strix/interface/main.py index 4eae05274..2ed6e5039 100644 --- a/strix/interface/main.py +++ b/strix/interface/main.py @@ -23,9 +23,11 @@ persist_current, ) from strix.config.models import ( + RECOMMENDED_MODEL_NAMES, StrixProvider, configure_sdk_model_defaults, is_known_openai_bare_model, + is_recommended_or_frontier_model, ) from strix.core.paths import run_dir_for, runtime_state_dir from strix.interface.cli import run_cli @@ -254,6 +256,32 @@ async def warm_up_llm() -> None: ) sys.exit(1) + if raw_model and not is_recommended_or_frontier_model(raw_model): + warn_text = Text() + warn_text.append("MODEL QUALITY WARNING", style="bold yellow") + warn_text.append("\n\n", style="white") + warn_text.append(f"'{raw_model}'", style="bold cyan") + warn_text.append( + " is not a recommended frontier model for Strix.\nSecurity scans work best with:\n", + style="white", + ) + for recommended_model in RECOMMENDED_MODEL_NAMES: + warn_text.append(f"• {recommended_model}\n", style="bold cyan") + warn_text.append( + "\nYou can continue, but weaker models may miss vulnerabilities " + "or produce lower-quality findings.", + style="white", + ) + console.print( + Panel( + warn_text, + title="[bold white]STRIX", + title_align="left", + border_style="yellow", + padding=(1, 2), + ), + ) + model = StrixProvider().get_model(raw_model) await asyncio.wait_for( model.get_response( @@ -310,6 +338,7 @@ def _positive_budget(value: str) -> float: except ValueError as exc: raise argparse.ArgumentTypeError(f"invalid float value: {value!r}") from exc import math + if not math.isfinite(budget) or budget <= 0: raise argparse.ArgumentTypeError("must be a finite number greater than 0") return budget diff --git a/strix/interface/tui/app.py b/strix/interface/tui/app.py index 4e43d9795..05d920d7c 100644 --- a/strix/interface/tui/app.py +++ b/strix/interface/tui/app.py @@ -83,7 +83,7 @@ def _on_key(self, event: events.Key) -> None: super()._on_key(event) - @on(TextArea.Changed) # type: ignore[misc] + @on(TextArea.Changed) # type: ignore[untyped-decorator] def _update_height(self, _event: TextArea.Changed | None = None) -> None: if not self.parent: return @@ -1549,7 +1549,7 @@ def _render_chat_content(self, msg_data: dict[str, Any]) -> Any: return AgentMessageRenderer.render_simple(content) - @on(Tree.NodeHighlighted) # type: ignore[misc] + @on(Tree.NodeHighlighted) # type: ignore[untyped-decorator] def handle_tree_highlight(self, event: Tree.NodeHighlighted) -> None: if len(self.screen_stack) > 1 or self.show_splash: return @@ -1569,7 +1569,7 @@ def handle_tree_highlight(self, event: Tree.NodeHighlighted) -> None: if agent_id: self.selected_agent_id = agent_id - @on(Tree.NodeSelected) # type: ignore[misc] + @on(Tree.NodeSelected) # type: ignore[untyped-decorator] def handle_tree_node_selected(self, event: Tree.NodeSelected) -> None: if len(self.screen_stack) > 1 or self.show_splash: return diff --git a/strix/runtime/docker_client.py b/strix/runtime/docker_client.py index 497ae2f21..d41a5a30b 100644 --- a/strix/runtime/docker_client.py +++ b/strix/runtime/docker_client.py @@ -48,7 +48,7 @@ class StrixDockerSandboxClient(DockerSandboxClient): # Host directories to bind-mount into the container, set by the docker # backend before ``create()``. Each item is ``{source, target, read_only}``. - strix_bind_mounts: list[dict[str, Any]] = [] # overridden per-instance in backends.py + strix_bind_mounts: list[dict[str, Any]] | None = None async def _create_container( self, diff --git a/tests/test_models.py b/tests/test_models.py new file mode 100644 index 000000000..72fc523d7 --- /dev/null +++ b/tests/test_models.py @@ -0,0 +1,62 @@ +"""Tests for LLM model recommendation helpers.""" + +from __future__ import annotations + +import pytest + +from strix.config.models import RECOMMENDED_MODEL_NAMES, is_recommended_or_frontier_model + + +@pytest.mark.parametrize("model_name", RECOMMENDED_MODEL_NAMES) +def test_recommended_models_are_accepted(model_name: str) -> None: + assert is_recommended_or_frontier_model(model_name) + + +def test_recommended_models_are_matched_case_insensitively() -> None: + assert is_recommended_or_frontier_model("Vertex_AI/Gemini-3-Pro-Preview") + + +@pytest.mark.parametrize( + "model_name", + [ + "gpt-5.5", + "litellm/openai/gpt-5.4-pro", + "azure_ai/gpt-5.5-pro", + "bedrock_mantle/openai.gpt-5.5", + "anthropic/claude-opus-4-8", + "anthropic.claude-opus-4-8", + "vertex_ai/claude-sonnet-4-6@default", + "any-llm/anthropic/claude-sonnet-4-6", + "vertex_ai/gemini-3.1-pro-preview", + "openrouter/google/gemini-3.1-pro-preview", + "xai/grok-4.3", + "openrouter/x-ai/grok-4", + "deepseek/deepseek-v4-pro", + "deepseek/deepseek-r1-0528", + "deepseek/deepseek-reasoner", + "dashscope/qwen3-max-2026-01-23", + "qwen3.7-max", + "moonshot/kimi-k2.6", + "kimi-k2.7-code", + "mistral/mistral-medium-3-5", + "mistral/magistral-medium-latest", + ], +) +def test_frontier_model_families_are_accepted(model_name: str) -> None: + assert is_recommended_or_frontier_model(model_name) + + +@pytest.mark.parametrize( + "model_name", + [ + "", + "openai/gpt-4.1", + "anthropic/claude-3-5-sonnet-latest", + "ollama/llama3.1", + "deepseek/deepseek-chat", + "custom-ollama/gpt-5-mini-local", + "custom-provider/claude-opus-4-local", + ], +) +def test_non_frontier_models_are_rejected(model_name: str) -> None: + assert not is_recommended_or_frontier_model(model_name)