Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions google/genai/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,13 @@ def _ContentEmbeddingStatistics_from_vertex(
if getv(from_object, ['token_count']) is not None:
setv(to_object, ['token_count'], getv(from_object, ['token_count']))

if getv(from_object, ['tokensDetails']) is not None:
setv(
to_object,
['tokens_details'],
[item for item in getv(from_object, ['tokensDetails'])],
)

return to_object


Expand Down Expand Up @@ -1136,6 +1143,8 @@ def _EmbedContentResponse_from_vertex(
stats = {}
if usage_metadata and usage_metadata.get('promptTokenCount'):
stats['token_count'] = usage_metadata['promptTokenCount']
if usage_metadata and usage_metadata.get('promptTokensDetails'):
stats['tokensDetails'] = usage_metadata['promptTokensDetails']
if truncated:
stats['truncated'] = truncated
embedding['statistics'] = stats
Expand Down
6 changes: 6 additions & 0 deletions google/genai/tests/models/test_embed_content.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,12 @@ def test_gemini_embedding_2_content_combination(client):
assert response.embeddings is not None
assert len(response.embeddings) == 1
assert len(response.embeddings[0].values) == 100
if client._api_client.vertexai:
statistics = response.embeddings[0].statistics
assert statistics is not None
assert statistics.token_count is not None
assert statistics.tokens_details is not None
assert len(statistics.tokens_details) > 0


@pytest.mark.asyncio
Expand Down
9 changes: 9 additions & 0 deletions google/genai/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -8564,6 +8564,11 @@ class ContentEmbeddingStatistics(_common.BaseModel):
description="""Gemini Enterprise Agent Platform only. Number of tokens of the input text.
""",
)
tokens_details: Optional[list[ModalityTokenCount]] = Field(
default=None,
description="""Gemini Enterprise Agent Platform only. List of modalities and their token count for the input content.
""",
)


class ContentEmbeddingStatisticsDict(TypedDict, total=False):
Expand All @@ -8578,6 +8583,10 @@ class ContentEmbeddingStatisticsDict(TypedDict, total=False):
"""Gemini Enterprise Agent Platform only. Number of tokens of the input text.
"""

tokens_details: Optional[list[ModalityTokenCountDict]]
"""Gemini Enterprise Agent Platform only. List of modalities and their token count for the input content.
"""


ContentEmbeddingStatisticsOrDict = Union[
ContentEmbeddingStatistics, ContentEmbeddingStatisticsDict
Expand Down
Loading