Skip to content

Commit 14a32a6

Browse files
committed
fixing the temporal tests
1 parent bae85cd commit 14a32a6

File tree

7 files changed

+61
-87
lines changed

7 files changed

+61
-87
lines changed

.github/workflows/agentex-tutorials-test.yml

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,20 @@ jobs:
5555
- name: Pull latest AgentEx image
5656
run: |
5757
echo "🐳 Pulling latest Scale AgentEx Docker image..."
58-
docker pull ghcr.io/scaleapi/scale-agentex/agentex:latest
59-
echo "✅ Successfully pulled AgentEx Docker image"
58+
max_attempts=3
59+
attempt=1
60+
while [ $attempt -le $max_attempts ]; do
61+
echo "Attempt $attempt of $max_attempts..."
62+
if docker pull ghcr.io/scaleapi/scale-agentex/agentex:latest; then
63+
echo "✅ Successfully pulled AgentEx Docker image"
64+
exit 0
65+
fi
66+
echo "❌ Pull failed, waiting before retry..."
67+
sleep $((attempt * 10))
68+
attempt=$((attempt + 1))
69+
done
70+
echo "❌ Failed to pull image after $max_attempts attempts"
71+
exit 1
6072
6173
- name: Checkout scale-agentex repo
6274
uses: actions/checkout@v4

examples/tutorials/10_async/00_base/000_hello_acp/tests/test_agent.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -112,9 +112,9 @@ async def poll_for_response() -> None:
112112
):
113113
assert isinstance(message, TaskMessage)
114114
if message.content and message.content.type == "text" and message.content.author == "agent":
115-
assert "Hello! I've received your message" in message.content.content
116-
agent_response_found = True
117-
break
115+
if "Hello! I've received your message" in message.content.content:
116+
agent_response_found = True
117+
break
118118

119119
try:
120120
await asyncio.wait_for(poll_for_response(), timeout=30)

examples/tutorials/10_async/10_temporal/050_agent_chat_guardrails/project/workflow.py

Lines changed: 43 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,12 @@
1515

1616
from agentex.lib import adk
1717
from agentex.lib.types.acp import SendEventParams, CreateTaskParams
18-
from agentex.lib.adk.models import ModelSettings
1918
from agentex.lib.types.tracing import SGPTracingProcessorConfig
2019
from agentex.lib.utils.logging import make_logger
2120
from agentex.types.text_content import TextContent
2221
from agentex.lib.utils.model_utils import BaseModel
23-
from agentex.lib.core.base.run_context import RunContextWrapper
22+
23+
from agents import ModelSettings, RunContextWrapper
2424
from agentex.lib.environment_variables import EnvironmentVariables
2525
from agentex.lib.core.temporal.types.workflow import SignalName
2626
from agentex.lib.core.temporal.workflows.workflow import BaseWorkflow
@@ -36,6 +36,7 @@
3636

3737
class GuardrailFunctionOutput(BaseModel):
3838
"""Output from a guardrail function."""
39+
3940
output_info: Dict[str, Any]
4041
tripwire_triggered: bool
4142

@@ -99,10 +100,7 @@ async def calculator(context: RunContextWrapper, args: str) -> str: # noqa: ARG
99100
b = parsed_args.get("b")
100101

101102
if operation is None or a is None or b is None:
102-
return (
103-
"Error: Missing required parameters. "
104-
"Please provide 'operation', 'a', and 'b'."
105-
)
103+
return "Error: Missing required parameters. Please provide 'operation', 'a', and 'b'."
106104

107105
# Convert to numbers
108106
try:
@@ -124,10 +122,7 @@ async def calculator(context: RunContextWrapper, args: str) -> str: # noqa: ARG
124122
result = a / b
125123
else:
126124
supported_ops = "add, subtract, multiply, divide"
127-
return (
128-
f"Error: Unknown operation '{operation}'. "
129-
f"Supported operations: {supported_ops}."
130-
)
125+
return f"Error: Unknown operation '{operation}'. Supported operations: {supported_ops}."
131126

132127
# Format the result nicely
133128
if result == int(result):
@@ -160,9 +155,7 @@ async def calculator(context: RunContextWrapper, args: str) -> str: # noqa: ARG
160155

161156
# Define the spaghetti guardrail function
162157
async def check_spaghetti_guardrail(
163-
ctx: RunContextWrapper[None],
164-
agent: Agent,
165-
input: str | list
158+
ctx: RunContextWrapper[None], agent: Agent, input: str | list
166159
) -> GuardrailFunctionOutput:
167160
"""
168161
A simple guardrail that checks if 'spaghetti' is mentioned in the input.
@@ -185,25 +178,22 @@ async def check_spaghetti_guardrail(
185178
return GuardrailFunctionOutput(
186179
output_info={
187180
"contains_spaghetti": contains_spaghetti,
188-
"checked_text": (
189-
input_text[:200] + "..."
190-
if len(input_text) > 200 else input_text
191-
),
181+
"checked_text": (input_text[:200] + "..." if len(input_text) > 200 else input_text),
192182
"rejection_message": (
193183
"I'm sorry, but I cannot process messages about spaghetti. "
194184
"This guardrail was put in place for demonstration purposes. "
195185
"Please ask me about something else!"
196-
) if contains_spaghetti else None
186+
)
187+
if contains_spaghetti
188+
else None,
197189
},
198-
tripwire_triggered=contains_spaghetti
190+
tripwire_triggered=contains_spaghetti,
199191
)
200192

201193

202194
# Define soup input guardrail function
203195
async def check_soup_guardrail(
204-
ctx: RunContextWrapper[None],
205-
agent: Agent,
206-
input: str | list
196+
ctx: RunContextWrapper[None], agent: Agent, input: str | list
207197
) -> GuardrailFunctionOutput:
208198
"""
209199
A guardrail that checks if 'soup' is mentioned in the input.
@@ -226,107 +216,88 @@ async def check_soup_guardrail(
226216
return GuardrailFunctionOutput(
227217
output_info={
228218
"contains_soup": contains_soup,
229-
"checked_text": (
230-
input_text[:200] + "..."
231-
if len(input_text) > 200 else input_text
232-
),
219+
"checked_text": (input_text[:200] + "..." if len(input_text) > 200 else input_text),
233220
"rejection_message": (
234221
"I'm sorry, but I cannot process messages about soup. "
235222
"This is a demonstration guardrail for testing purposes. "
236223
"Please ask about something other than soup!"
237-
) if contains_soup else None
224+
)
225+
if contains_soup
226+
else None,
238227
},
239-
tripwire_triggered=contains_soup
228+
tripwire_triggered=contains_soup,
240229
)
241230

242231

243232
# Create the input guardrails
244-
SPAGHETTI_GUARDRAIL = TemporalInputGuardrail(
245-
guardrail_function=check_spaghetti_guardrail,
246-
name="spaghetti_guardrail"
247-
)
233+
SPAGHETTI_GUARDRAIL = TemporalInputGuardrail(guardrail_function=check_spaghetti_guardrail, name="spaghetti_guardrail")
248234

249-
SOUP_GUARDRAIL = TemporalInputGuardrail(
250-
guardrail_function=check_soup_guardrail,
251-
name="soup_guardrail"
252-
)
235+
SOUP_GUARDRAIL = TemporalInputGuardrail(guardrail_function=check_soup_guardrail, name="soup_guardrail")
253236

254237

255238
# Define pizza output guardrail function
256-
async def check_pizza_guardrail(
257-
ctx: RunContextWrapper[None],
258-
agent: Agent,
259-
output: str
260-
) -> GuardrailFunctionOutput:
239+
async def check_pizza_guardrail(ctx: RunContextWrapper[None], agent: Agent, output: str) -> GuardrailFunctionOutput:
261240
"""
262241
An output guardrail that prevents mentioning pizza.
263242
"""
264243
output_text = output.lower() if isinstance(output, str) else ""
265244
contains_pizza = "pizza" in output_text
266-
245+
267246
return GuardrailFunctionOutput(
268247
output_info={
269248
"contains_pizza": contains_pizza,
270249
"rejection_message": (
271250
"I cannot provide this response as it mentions pizza. "
272251
"Due to content policies, I need to avoid discussing pizza. "
273252
"Let me provide a different response."
274-
) if contains_pizza else None
253+
)
254+
if contains_pizza
255+
else None,
275256
},
276-
tripwire_triggered=contains_pizza
257+
tripwire_triggered=contains_pizza,
277258
)
278259

279260

280261
# Define sushi output guardrail function
281-
async def check_sushi_guardrail(
282-
ctx: RunContextWrapper[None],
283-
agent: Agent,
284-
output: str
285-
) -> GuardrailFunctionOutput:
262+
async def check_sushi_guardrail(ctx: RunContextWrapper[None], agent: Agent, output: str) -> GuardrailFunctionOutput:
286263
"""
287264
An output guardrail that prevents mentioning sushi.
288265
"""
289266
output_text = output.lower() if isinstance(output, str) else ""
290267
contains_sushi = "sushi" in output_text
291-
268+
292269
return GuardrailFunctionOutput(
293270
output_info={
294271
"contains_sushi": contains_sushi,
295272
"rejection_message": (
296273
"I cannot mention sushi in my response. "
297274
"This guardrail prevents discussions about sushi for demonstration purposes. "
298275
"Please let me provide information about other topics."
299-
) if contains_sushi else None
276+
)
277+
if contains_sushi
278+
else None,
300279
},
301-
tripwire_triggered=contains_sushi
280+
tripwire_triggered=contains_sushi,
302281
)
303282

304283

305284
# Create the output guardrails
306-
PIZZA_GUARDRAIL = TemporalOutputGuardrail(
307-
guardrail_function=check_pizza_guardrail,
308-
name="pizza_guardrail"
309-
)
285+
PIZZA_GUARDRAIL = TemporalOutputGuardrail(guardrail_function=check_pizza_guardrail, name="pizza_guardrail")
310286

311-
SUSHI_GUARDRAIL = TemporalOutputGuardrail(
312-
guardrail_function=check_sushi_guardrail,
313-
name="sushi_guardrail"
314-
)
287+
SUSHI_GUARDRAIL = TemporalOutputGuardrail(guardrail_function=check_sushi_guardrail, name="sushi_guardrail")
315288

316289

317290
# Example output guardrail function (kept for reference)
318291
async def check_output_length_guardrail(
319-
ctx: RunContextWrapper[None],
320-
agent: Agent,
321-
output: str
292+
ctx: RunContextWrapper[None], agent: Agent, output: str
322293
) -> GuardrailFunctionOutput:
323294
"""
324295
A simple output guardrail that checks if the response is too long.
325296
"""
326297
# Check the length of the output
327298
max_length = 1000 # Maximum allowed characters
328299
is_too_long = len(output) > max_length if isinstance(output, str) else False
329-
300+
330301
return GuardrailFunctionOutput(
331302
output_info={
332303
"output_length": len(output) if isinstance(output, str) else 0,
@@ -336,9 +307,11 @@ async def check_output_length_guardrail(
336307
f"I'm sorry, but my response is too long ({len(output)} characters). "
337308
f"Please ask a more specific question so I can provide a concise answer "
338309
f"(max {max_length} characters)."
339-
) if is_too_long else None
310+
)
311+
if is_too_long
312+
else None,
340313
},
341-
tripwire_triggered=is_too_long
314+
tripwire_triggered=is_too_long,
342315
)
343316

344317

@@ -353,10 +326,7 @@ async def check_output_length_guardrail(
353326
# Create the calculator tool
354327
CALCULATOR_TOOL = FunctionTool(
355328
name="calculator",
356-
description=(
357-
"Performs basic arithmetic operations (add, subtract, multiply, "
358-
"divide) on two numbers."
359-
),
329+
description=("Performs basic arithmetic operations (add, subtract, multiply, divide) on two numbers."),
360330
params_json_schema={
361331
"type": "object",
362332
"properties": {
@@ -390,26 +360,21 @@ def __init__(self):
390360
@workflow.signal(name=SignalName.RECEIVE_EVENT)
391361
@override
392362
async def on_task_event_send(self, params: SendEventParams) -> None:
393-
394363
if not params.event.content:
395364
return
396365
if params.event.content.type != "text":
397366
raise ValueError(f"Expected text message, got {params.event.content.type}")
398367

399368
if params.event.content.author != "user":
400-
raise ValueError(
401-
f"Expected user message, got {params.event.content.author}"
402-
)
369+
raise ValueError(f"Expected user message, got {params.event.content.author}")
403370

404371
if self._state is None:
405372
raise ValueError("State is not initialized")
406373

407374
# Increment the turn number
408375
self._state.turn_number += 1
409376
# Add the new user message to the message history
410-
self._state.input_list.append(
411-
{"role": "user", "content": params.event.content.content}
412-
)
377+
self._state.input_list.append({"role": "user", "content": params.event.content.content})
413378

414379
async with adk.tracing.span(
415380
trace_id=params.task.id,
@@ -475,7 +440,7 @@ async def on_task_event_send(self, params: SendEventParams) -> None:
475440
input_guardrails=[SPAGHETTI_GUARDRAIL, SOUP_GUARDRAIL],
476441
output_guardrails=[PIZZA_GUARDRAIL, SUSHI_GUARDRAIL],
477442
)
478-
443+
479444
# Update state with the final input list from result
480445
if self._state and result:
481446
final_list = getattr(result, "final_input_list", None)

examples/tutorials/10_async/10_temporal/060_open_ai_agents_sdk_hello_world/manifest.yaml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,6 @@ agent:
104104
# Optional: Set Environment variables for running your agent locally as well
105105
# as for deployment later on
106106
env:
107-
OPENAI_API_KEY: ""
108107
# OPENAI_BASE_URL: "<YOUR_OPENAI_BASE_URL_HERE>"
109108
OPENAI_ORG_ID: ""
110109

examples/tutorials/10_async/10_temporal/070_open_ai_agents_sdk_tools/manifest.yaml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,6 @@ agent:
102102
# Optional: Set Environment variables for running your agent locally as well
103103
# as for deployment later on
104104
env:
105-
OPENAI_API_KEY: ""
106105
# OPENAI_BASE_URL: "<YOUR_OPENAI_BASE_URL_HERE>"
107106
OPENAI_ORG_ID: ""
108107

examples/tutorials/10_async/10_temporal/080_open_ai_agents_sdk_human_in_the_loop/manifest.yaml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,6 @@ agent:
104104
# Optional: Set Environment variables for running your agent locally as well
105105
# as for deployment later on
106106
env:
107-
OPENAI_API_KEY: ""
108107
# OPENAI_BASE_URL: "<YOUR_OPENAI_BASE_URL_HERE>"
109108
OPENAI_ORG_ID: ""
110109

examples/tutorials/10_async/10_temporal/090_claude_agents_sdk_mvp/tests/test_agent.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
# Configuration from environment variables
2020
AGENTEX_API_BASE_URL = os.environ.get("AGENTEX_API_BASE_URL", "http://localhost:5003")
21-
AGENT_NAME = os.environ.get("AGENT_NAME", "claude_")
21+
AGENT_NAME = os.environ.get("AGENT_NAME", "claude-mvp-agent")
2222

2323

2424
@pytest_asyncio.fixture

0 commit comments

Comments
 (0)