-
Notifications
You must be signed in to change notification settings - Fork 73
fix: astream output #358
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
fix: astream output #358
Conversation
|
The PR description has been updated. Please fill out the template for your PR to be reviewed. |
Merge ProtectionsYour pull request matches the following merge protections and will not be merged until they are valid. 🟢 Enforce conventional commitWonderful, this rule succeeded.Make sure that we follow https://www.conventionalcommits.org/en/v1.0.0/
|
psschwei
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
small nit, but otherwise lgtm
mellea/core/base.py
Outdated
| f"Cannot use `ModelOutputThunk.astream()` when the generate function is using `{self._generate_type.name}`" | ||
| ) | ||
| # Beginning value | ||
| beginning_value: str = self._underlying_value # type: ignore |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: should we just store the length of the underlying value? would give us a slightly smaller memory footprint.
6b190a0 to
0a4f108
Compare
psschwei
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
Signed-off-by: Akihiko Kuroda <akihikokuroda2020@gmail.com>
Signed-off-by: Akihiko Kuroda <akihikokuroda2020@gmail.com>
0a4f108 to
c873cc4
Compare
| return ( | ||
| self._underlying_value | ||
| if beginning_length is None | ||
| else self._underlying_value[beginning_length:] # type: ignore | ||
| ) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Something I came across when testing:
beginning_length is calculated as: beginning_length = 0 if self._underlying_value is None else len(str(self._underlying_value))
This means beginning_length is always an integer (0 or positive), never None. So should we do something like:
| return ( | |
| self._underlying_value | |
| if beginning_length is None | |
| else self._underlying_value[beginning_length:] # type: ignore | |
| ) | |
| return ( | |
| self._underlying_value | |
| if beginning_length == 0 | |
| else self._underlying_value[beginning_length:] # type: ignore | |
| ) |
?
|
I generated a couple of tests to verify the behavior here: import asyncio
import pytest
from mellea.backends import ModelOption
from mellea.core import ModelOutputThunk
from mellea.stdlib.session import MelleaSession, start_session
@pytest.fixture(scope="module")
def m_session(gh_run):
"""Create a session for testing streaming behavior."""
m = start_session(model_options={ModelOption.MAX_NEW_TOKENS: 50})
yield m
del m
@pytest.mark.asyncio
async def test_astream_returns_only_new_chunks(m_session: MelleaSession):
"""Test that astream() returns only new chunks on subsequent calls, not the entire accumulated content.
This tests the fix from PR #358 where beginning_length is tracked to return
only the delta between calls.
"""
# Create a streaming output
out = m_session.instruct("Count from 1 to 10")
# First call to astream should return some initial content
first_chunk = await out.astream()
assert isinstance(first_chunk, str)
first_length = len(first_chunk)
# If not yet computed, call astream again
if not out.is_computed():
second_chunk = await out.astream()
assert isinstance(second_chunk, str)
# The second chunk should NOT include the first chunk's content
# It should only contain NEW content
# The total accumulated value should be first_chunk + second_chunk
accumulated_value = out.value
assert accumulated_value is not None
# Verify that second_chunk is only the new part
# (not the entire accumulated content)
if len(second_chunk) > 0:
# If we got new content, verify it's a substring of the accumulated value
# starting after the first chunk
assert accumulated_value.endswith(second_chunk) or second_chunk in accumulated_value
# The second chunk should be shorter than or equal to the total accumulated value
assert len(second_chunk) <= len(accumulated_value)
# The second chunk should not be identical to the full accumulated value
# (unless the first chunk was empty)
if first_length > 0:
assert second_chunk != accumulated_value
@pytest.mark.asyncio
async def test_astream_full_completion(m_session: MelleaSession):
"""Test that repeatedly calling astream() eventually returns the full completed output."""
out = m_session.instruct("Say hello")
accumulated_chunks = []
# Keep calling astream until completion
while not out.is_computed():
chunk = await out.astream()
accumulated_chunks.append(chunk)
# Get final chunk after completion
final_chunk = await out.astream()
accumulated_chunks.append(final_chunk)
# The concatenation of all chunks should equal the final value
concatenated = "".join(accumulated_chunks)
assert out.value is not None
assert concatenated == out.value
@pytest.mark.asyncio
async def test_astream_on_computed_thunk(m_session: MelleaSession):
"""Test that astream() on an already computed thunk returns the full value."""
out = m_session.instruct("Hello world")
# Wait for completion
final_value = await out.avalue()
assert out.is_computed()
# Calling astream on a computed thunk should return the full value
streamed_value = await out.astream()
assert streamed_value == final_value
@pytest.mark.asyncio
async def test_astream_empty_initial_value():
"""Test astream behavior when _underlying_value starts as None."""
# Create a thunk without initial value
thunk = ModelOutputThunk(None)
# Manually set it to computed with a value to test the edge case
thunk._underlying_value = "test content"
thunk._computed = True
# astream should return the full value when computed
result = await thunk.astream()
assert result == "test content"
@pytest.mark.asyncio
async def test_avalue_returns_full_content(m_session: MelleaSession):
"""Test that avalue() always returns the complete accumulated content."""
out = m_session.instruct("Count to 5")
# avalue should wait for completion and return full content
full_value = await out.avalue()
assert isinstance(full_value, str)
assert len(full_value) > 0
assert out.is_computed()
assert out.value == full_value
if __name__ == "__main__":
pytest.main([__file__, "-v"]) |
Misc PR
Type of PR
Description
Testing