10 changes: 5 additions & 5 deletions .pre-commit-config.yaml
@@ -2,25 +2,25 @@ exclude: ^(scratchpad/)

repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.11.5
rev: v0.14.14
hooks:
- id: ruff-format
name: "Ruff formatter"
args: [--config=pyproject.toml]
files: '^(mellea|test|cli|docs).*\.(py|ipynb)$'
types_or: [python, jupyter]
- id: ruff
name: "Ruff linter"
args: [--exit-non-zero-on-fix, --fix, --config=pyproject.toml]
files: '^(mellea).*\.(py|ipynb)$'
types_or: [python, jupyter]

- repo: local
hooks:
- id: mypy
name: MyPy
entry: uv run --no-sync mypy mellea
entry: uv run --no-sync mypy .
pass_filenames: false
language: system
files: '^(mellea|test|cli|docs).*\.(py|ipynb)$'
types_or: [python, jupyter]

- repo: https://github.com/astral-sh/uv-pre-commit
rev: 0.7.8
2 changes: 1 addition & 1 deletion cli/decompose/pipeline.py
@@ -5,9 +5,9 @@
from typing_extensions import NotRequired

from mellea import MelleaSession
from mellea.backends import ModelOption
from mellea.backends.ollama import OllamaModelBackend
from mellea.backends.openai import OpenAIBackend
from mellea.backends import ModelOption

from .prompt_modules import (
constraint_extractor,
3 changes: 2 additions & 1 deletion cli/eval/commands.py
@@ -1,5 +1,6 @@
"""Use the eval command for LLM-as-a-judge evaluation, given a (set of) test file(s) consisting of prompts, instructions, and optionally, targets.
Instantiate a generator model to produce candidate responses, and a judge model to determine whether the instructions have been followed."""
Instantiate a generator model to produce candidate responses, and a judge model to determine whether the instructions have been followed.
"""

import typer

37 changes: 21 additions & 16 deletions cli/eval/runner.py
@@ -1,15 +1,16 @@
import json
import re
from pathlib import Path
from typing import List

from rich.console import Console
from rich.progress import BarColumn, Progress, SpinnerColumn, TextColumn

import mellea
from mellea.backends import ModelOption
from mellea.backends.backend import Backend
from mellea.core import ModelOutputThunk
from mellea.stdlib.components import SimpleComponent
from mellea.stdlib.components.unit_test_eval import TestBasedEval
from mellea.backends import ModelOption

from rich.console import Console
from rich.progress import BarColumn, Progress, SpinnerColumn, TextColumn

console = Console()

@@ -78,7 +79,6 @@ def create_session(
backend: str, model: str | None, max_tokens: int | None
) -> mellea.MelleaSession:
"""Create a mellea session with the specified backend and model."""

model_id = None
if model:
if model.isupper() or "_" in model:
@@ -93,6 +93,7 @@ def create_session(

try:
backend_lower = backend.lower()
backend_instance: Backend

if backend_lower == "ollama":
from mellea.backends.ollama import OllamaModelBackend
@@ -130,7 +131,7 @@ def create_session(
from mellea.backends.litellm import LiteLLMBackend

backend_instance = LiteLLMBackend(
model_id=model_id,
model_id=str(model_id),
model_options={ModelOption.MAX_NEW_TOKENS: max_tokens},
)

@@ -153,7 +154,7 @@ def create_session(


def run_evaluations(
test_files: List[str],
test_files: list[str],
backend: str,
model: str | None,
max_gen_tokens: int | None,
@@ -173,14 +174,14 @@
"instructions": a set (in string form) of requirements which the generation should follow; the judge will evaluate if these are satisfied
"examples": a list of entries containing an input_id, an input(prompt), and a list of targets. Each input may have multiple (or no) targets; inputs and targets are in messages format.
"""
all_test_evals: List[TestBasedEval] = []
all_test_evals: list[TestBasedEval] = []

for test_file in test_files:
try:
test_evals = TestBasedEval.from_json_file(test_file)
all_test_evals.extend(test_evals)
console.print(f"Loaded {len(test_evals)} test evaluations from {test_file}")
except Exception as e:
except Exception:
console.print(f"Error loading {test_file}")

if not all_test_evals:
@@ -195,8 +196,11 @@
console.print(f"Judge model: {judge_model}")

m = create_session(backend=backend, model=model, max_tokens=max_gen_tokens)
# Use same backend as generator if judge_backend not specified
judge_session = create_session(
backend=judge_backend, model=judge_model, max_tokens=max_judge_tokens
backend=judge_backend if judge_backend else backend,
model=judge_model,
max_tokens=max_judge_tokens,
)

all_results = []
@@ -240,12 +244,13 @@ def execute_test_eval(
For each input in the test, generate a response using generation_session
Then, after all inputs are processed, validate using judge_session.
"""

input_results = []

# for all inputs, generate responses with generator
for idx, input_text in enumerate(test_eval.inputs):
result: ModelOutputThunk = generation_session.act(input_text)
result: ModelOutputThunk = generation_session.act(
SimpleComponent(instruction=input_text)
)
model_output = str(result)

targets_for_input = (
@@ -267,7 +272,7 @@
input_text=input_text,
model_output=model_output,
validation_passed=passed,
score=score,
score=score if score is not None else 0,
validation_reason=justification,
)
input_results.append(input_result)
@@ -301,7 +306,7 @@ def parse_judge_output(judge_output: str):
return None, judge_output


def save_results(results: List[TestEvalResult], output_path: str, output_format: str):
def save_results(results: list[TestEvalResult], output_path: str, output_format: str):
output_path_obj = Path(output_path)
if output_path_obj.suffix != f".{output_format}":
output_path_obj = Path(f"{output_path}.{output_format}")
@@ -333,7 +338,7 @@ def save_results(results: List[TestEvalResult], output_path: str, output_format:
console.print(f"Results saved to {output_path}")


def summary_stats(results: List[TestEvalResult]):
def summary_stats(results: list[TestEvalResult]):
total_inputs = sum(r.total_count for r in results)
passed_inputs = sum(r.passed_count for r in results)
overall_pass_rate = passed_inputs / total_inputs if total_inputs > 0 else 0.0
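Note on the test-file layout: the run_evaluations docstring above describes each file as carrying an instructions string plus an examples list of entries with an input_id, an input (prompt), and zero or more targets, with inputs and targets in messages format. Below is a minimal sketch of what such a file could contain, written as Python for illustration; the key names and message shape are assumptions inferred from the docstring, not a schema confirmed by TestBasedEval.from_json_file.

```python
# Illustrative only: an assumed test-file shape based on the run_evaluations
# docstring. Key names ("name", "instructions", "examples", "input_id",
# "input", "targets") are guesses, not the library's confirmed schema.
import json

assumed_test_file = {
    "name": "politeness-check",
    "instructions": "Answer politely. Keep the response under two sentences.",
    "examples": [
        {
            "input_id": "ex-001",
            "input": [{"role": "user", "content": "Summarize the plot of Hamlet."}],
            "targets": [
                [{"role": "assistant", "content": "A Danish prince avenges his murdered father."}]
            ],
        },
        {
            "input_id": "ex-002",
            "input": [{"role": "user", "content": "What is 2 + 2?"}],
            "targets": [],  # an input may have no targets
        },
    ],
}

with open("politeness_check.json", "w", encoding="utf-8") as f:
    json.dump(assumed_test_file, f, indent=2)
```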
2 changes: 1 addition & 1 deletion cli/m.py
@@ -4,8 +4,8 @@

from cli.alora.commands import alora_app
from cli.decompose import app as decompose_app
from cli.serve.app import serve
from cli.eval.commands import eval_app
from cli.serve.app import serve

cli = typer.Typer(name="m", no_args_is_help=True)

Empty file added docs/__init__.py
Empty file.
Empty file added docs/examples/__init__.py
Empty file.
9 changes: 5 additions & 4 deletions docs/examples/aLora/101_example.py
@@ -2,13 +2,14 @@

import time

from mellea import MelleaSession
from mellea.backends.aloras.huggingface.granite_aloras import HFConstraintAlora
from mellea.backends.cache import SimpleLRUCache
from mellea.backends.huggingface import LocalHFBackend
from mellea.stdlib.base import ChatContext, GenerateLog
from mellea.stdlib.requirement import ALoraRequirement, Requirement

from mellea import MelleaSession
from mellea.backends.cache import SimpleLRUCache
from mellea.backends.huggingface import LocalHFBackend

# Define a backend and add the constraint aLora
backend = LocalHFBackend(
model_id="ibm-granite/granite-3.2-8b-instruct", cache=SimpleLRUCache(5)
@@ -21,7 +22,7 @@
backend=backend,
)

backend.add_alora(custom_stembolt_failure_constraint)
backend.add_alora(custom_stembolt_failure_constraint) # type: ignore[attr-defined]

# Create M session
m = MelleaSession(backend, ctx=ChatContext())
7 changes: 4 additions & 3 deletions docs/examples/agents/react.py
@@ -4,6 +4,7 @@
import inspect
import json
from collections.abc import Callable
from enum import Enum
from typing import Literal

import pydantic
@@ -84,9 +85,9 @@ def call_tool(self, tool: ReactTool, kwargs_json: str):

def tool_name_schema(self):
names = self.tool_names()
fields = dict()
fields["tool"] = Literal[*names]
return pydantic.create_model("ToolSelectionSchema", **fields)
# Python 3.10 compatible: use Enum instead of Literal[*names] (requires 3.11+)
ToolEnum = Enum("ToolEnum", {name: name for name in names})
return pydantic.create_model("ToolSelectionSchema", tool=(ToolEnum, ...))

def get_tool_from_schema(self, content: str):
schema = self.tool_name_schema()
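Note on the schema change above (the same edit appears in react_instruct.py below): `Literal[*names]` relies on unpacking inside a subscript, which is only valid syntax on Python 3.11+, so the tool-selection schema is rebuilt from a dynamically created Enum instead. A minimal, self-contained sketch of the pattern, assuming pydantic v2 and made-up tool names rather than anything from the example's own tool list:

```python
from enum import Enum

import pydantic

# Hypothetical tool names; in the example they come from self.tool_names().
names = ["search", "calculator"]

# Build an Enum whose member names and values mirror the tool names, then use it
# as the type of a required "tool" field. This works on Python 3.10, unlike
# Literal[*names], which needs 3.11+.
ToolEnum = Enum("ToolEnum", {name: name for name in names})
ToolSelectionSchema = pydantic.create_model("ToolSelectionSchema", tool=(ToolEnum, ...))

# Validate a model response that selected a tool (pydantic v2 API).
parsed = ToolSelectionSchema.model_validate_json('{"tool": "search"}')
print(parsed.tool)  # ToolEnum.search
```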
7 changes: 4 additions & 3 deletions docs/examples/agents/react_instruct.py
@@ -4,6 +4,7 @@
import inspect
import json
from collections.abc import Callable
from enum import Enum
from typing import Literal

import pydantic
@@ -81,9 +82,9 @@ def call_tool(self, tool: ReactTool, kwargs_json: str):

def tool_name_schema(self):
names = self.tool_names()
fields = dict()
fields["tool"] = Literal[*names]
return pydantic.create_model("ToolSelectionSchema", **fields)
# Python 3.10 compatible: use Enum instead of Literal[*names] (requires 3.11+)
ToolEnum = Enum("ToolEnum", {name: name for name in names})
return pydantic.create_model("ToolSelectionSchema", tool=(ToolEnum, ...))

def get_tool_from_schema(self, content: str):
schema = self.tool_name_schema()
7 changes: 3 additions & 4 deletions docs/examples/conftest.py
@@ -209,7 +209,7 @@ def pytest_ignore_collect(collection_path, path, config):
# Extract markers and check if we should skip
try:
markers = _extract_markers_from_file(collection_path)
should_skip, reason = _should_skip_collection(markers)
should_skip, _reason = _should_skip_collection(markers)
if should_skip:
# Return True to ignore this file completely
return True
@@ -233,7 +233,7 @@ def pytest_pycollect_makemodule(module_path, path, parent):
and "examples" in module_path.parts
):
# Check for optional imports
should_skip, reason = _check_optional_imports(module_path)
should_skip, _reason = _check_optional_imports(module_path)
if should_skip:
# Add to skip list and return None to prevent module creation
examples_to_skip.add(module_path.name)
@@ -257,7 +257,7 @@ def pytest_collect_file(parent: pytest.Dir, file_path: pathlib.PosixPath):
return

# Check for optional imports before creating ExampleFile
should_skip, reason = _check_optional_imports(file_path)
should_skip, _reason = _check_optional_imports(file_path)
if should_skip:
return None

@@ -344,7 +344,6 @@ def pytest_runtest_setup(item):
gh_run = int(os.environ.get("CICD", 0))

# Get config options (all default to False for examples)
ignore_all = False
ignore_gpu = False
ignore_ram = False
ignore_ollama = False
2 changes: 1 addition & 1 deletion docs/examples/context/contexts_with_sampling.py
@@ -28,7 +28,7 @@
print(f"Total Generation Attempts: {len(res.sample_generations)}")
print()

print(f"Getting index of another result.")
print("Getting index of another result.")
index = 0 # Just choose the first one.

print(
@@ -3,16 +3,15 @@
from typing import Literal

from mellea import generative, start_session
from mellea.core import Requirement
from mellea.stdlib.components.genslot import PreconditionException
from mellea.stdlib.requirements import simple_validate
from mellea.core import Requirement
from mellea.stdlib.sampling.base import RejectionSamplingStrategy


@generative
def classify_sentiment(text: str) -> Literal["positive", "negative", "unknown"]:
"""Classify the sentiment of the text."""
...


if __name__ == "__main__":
@@ -30,8 +29,8 @@ def classify_sentiment(text: str) -> Literal["positive", "negative", "unknown"]:
)

print(
f"Prompt to the model looked like:\n```\n{m.last_prompt()[0]['content']}\n```"
) # type: ignore
f"Prompt to the model looked like:\n```\n{m.last_prompt()[0]['content']}\n```" # type: ignore[index]
)
# Prompt to the model looked like:
# ```
# Your task is to imitate the output of the following function for the given arguments.
@@ -65,7 +64,7 @@ def classify_sentiment(text: str) -> Literal["positive", "negative", "unknown"]:
],
)
except PreconditionException as e:
print(f"exception: {str(e)}")
print(f"exception: {e!s}")

# Look at why the precondition validation failed.
print("Failure reasons:")
13 changes: 7 additions & 6 deletions docs/examples/image_text_models/vision_litellm_backend.py
@@ -3,6 +3,7 @@
"""Examples of using vision models with LiteLLM backend."""

import os
import pathlib

import litellm
from PIL import Image
@@ -11,7 +12,6 @@
from mellea.backends.litellm import LiteLLMBackend
from mellea.backends.openai import OpenAIBackend
from mellea.core import ImageBlock
import pathlib

# use LiteLLM to talk to Ollama or anthropic or.....
m = MelleaSession(LiteLLMBackend("ollama/granite3.2-vision"))
@@ -28,17 +28,18 @@
# test with PIL image
res_instruct = m.instruct(
"Is there a person on the image? Is the subject in the image smiling?",
images=[test_pil],
images=[test_pil], # type: ignore[arg-type]
)
print(f"Test with PIL and instruct: \n{str(res_instruct)}\n-----")
print(f"Test with PIL and instruct: \n{res_instruct!s}\n-----")
# print(m.last_prompt())

# with PIL image and using m.chat
res_chat = m.chat(
"How many eyes can you identify in the image? Explain.", images=[test_pil]
"How many eyes can you identify in the image? Explain.",
images=[test_pil], # type: ignore[arg-type]
)
print(f"Test with PIL and chat: \n{str(res_chat.content)}\n-----")
print(f"Test with PIL and chat: \n{res_chat.content!s}\n-----")

# and now without images again...
res_empty = m.instruct("How many eyes can you identify in the image?", images=[])
print(f"Test without image: \n{str(res_empty)}\n-----")
print(f"Test without image: \n{res_empty!s}\n-----")
3 changes: 2 additions & 1 deletion docs/examples/image_text_models/vision_ollama_chat.py
@@ -3,6 +3,7 @@
"""Example of using Ollama with vision models with linear context."""

import pathlib

from PIL import Image

from mellea import start_session
@@ -16,7 +17,7 @@
test_pil = Image.open(image_path)

# ask a question about the image
res = m.instruct("Is the subject in the image smiling?", images=[test_pil])
res = m.instruct("Is the subject in the image smiling?", images=[test_pil]) # type: ignore[arg-type]
print(f"Result:{res!s}")

# This instruction should refer to the first image.