1515
1616from agentex .lib import adk
1717from agentex .lib .types .acp import SendEventParams , CreateTaskParams
18- from agentex .lib .adk .models import ModelSettings
1918from agentex .lib .types .tracing import SGPTracingProcessorConfig
2019from agentex .lib .utils .logging import make_logger
2120from agentex .types .text_content import TextContent
2221from agentex .lib .utils .model_utils import BaseModel
23- from agentex .lib .core .base .run_context import RunContextWrapper
22+
23+ from agents import ModelSettings , RunContextWrapper
2424from agentex .lib .environment_variables import EnvironmentVariables
2525from agentex .lib .core .temporal .types .workflow import SignalName
2626from agentex .lib .core .temporal .workflows .workflow import BaseWorkflow
3636
3737class GuardrailFunctionOutput (BaseModel ):
3838 """Output from a guardrail function."""
39+
3940 output_info : Dict [str , Any ]
4041 tripwire_triggered : bool
4142
@@ -99,10 +100,7 @@ async def calculator(context: RunContextWrapper, args: str) -> str: # noqa: ARG
99100 b = parsed_args .get ("b" )
100101
101102 if operation is None or a is None or b is None :
102- return (
103- "Error: Missing required parameters. "
104- "Please provide 'operation', 'a', and 'b'."
105- )
103+ return "Error: Missing required parameters. Please provide 'operation', 'a', and 'b'."
106104
107105 # Convert to numbers
108106 try :
@@ -124,10 +122,7 @@ async def calculator(context: RunContextWrapper, args: str) -> str: # noqa: ARG
124122 result = a / b
125123 else :
126124 supported_ops = "add, subtract, multiply, divide"
127- return (
128- f"Error: Unknown operation '{ operation } '. "
129- f"Supported operations: { supported_ops } ."
130- )
125+ return f"Error: Unknown operation '{ operation } '. Supported operations: { supported_ops } ."
131126
132127 # Format the result nicely
133128 if result == int (result ):
@@ -160,9 +155,7 @@ async def calculator(context: RunContextWrapper, args: str) -> str: # noqa: ARG
160155
161156# Define the spaghetti guardrail function
162157async def check_spaghetti_guardrail (
163- ctx : RunContextWrapper [None ],
164- agent : Agent ,
165- input : str | list
158+ ctx : RunContextWrapper [None ], agent : Agent , input : str | list
166159) -> GuardrailFunctionOutput :
167160 """
168161 A simple guardrail that checks if 'spaghetti' is mentioned in the input.
@@ -185,25 +178,22 @@ async def check_spaghetti_guardrail(
185178 return GuardrailFunctionOutput (
186179 output_info = {
187180 "contains_spaghetti" : contains_spaghetti ,
188- "checked_text" : (
189- input_text [:200 ] + "..."
190- if len (input_text ) > 200 else input_text
191- ),
181+ "checked_text" : (input_text [:200 ] + "..." if len (input_text ) > 200 else input_text ),
192182 "rejection_message" : (
193183 "I'm sorry, but I cannot process messages about spaghetti. "
194184 "This guardrail was put in place for demonstration purposes. "
195185 "Please ask me about something else!"
196- ) if contains_spaghetti else None
186+ )
187+ if contains_spaghetti
188+ else None ,
197189 },
198- tripwire_triggered = contains_spaghetti
190+ tripwire_triggered = contains_spaghetti ,
199191 )
200192
201193
202194# Define soup input guardrail function
203195async def check_soup_guardrail (
204- ctx : RunContextWrapper [None ],
205- agent : Agent ,
206- input : str | list
196+ ctx : RunContextWrapper [None ], agent : Agent , input : str | list
207197) -> GuardrailFunctionOutput :
208198 """
209199 A guardrail that checks if 'soup' is mentioned in the input.
@@ -226,107 +216,88 @@ async def check_soup_guardrail(
226216 return GuardrailFunctionOutput (
227217 output_info = {
228218 "contains_soup" : contains_soup ,
229- "checked_text" : (
230- input_text [:200 ] + "..."
231- if len (input_text ) > 200 else input_text
232- ),
219+ "checked_text" : (input_text [:200 ] + "..." if len (input_text ) > 200 else input_text ),
233220 "rejection_message" : (
234221 "I'm sorry, but I cannot process messages about soup. "
235222 "This is a demonstration guardrail for testing purposes. "
236223 "Please ask about something other than soup!"
237- ) if contains_soup else None
224+ )
225+ if contains_soup
226+ else None ,
238227 },
239- tripwire_triggered = contains_soup
228+ tripwire_triggered = contains_soup ,
240229 )
241230
242231
243232# Create the input guardrails
244- SPAGHETTI_GUARDRAIL = TemporalInputGuardrail (
245- guardrail_function = check_spaghetti_guardrail ,
246- name = "spaghetti_guardrail"
247- )
233+ SPAGHETTI_GUARDRAIL = TemporalInputGuardrail (guardrail_function = check_spaghetti_guardrail , name = "spaghetti_guardrail" )
248234
249- SOUP_GUARDRAIL = TemporalInputGuardrail (
250- guardrail_function = check_soup_guardrail ,
251- name = "soup_guardrail"
252- )
235+ SOUP_GUARDRAIL = TemporalInputGuardrail (guardrail_function = check_soup_guardrail , name = "soup_guardrail" )
253236
254237
255238# Define pizza output guardrail function
256- async def check_pizza_guardrail (
257- ctx : RunContextWrapper [None ],
258- agent : Agent ,
259- output : str
260- ) -> GuardrailFunctionOutput :
239+ async def check_pizza_guardrail (ctx : RunContextWrapper [None ], agent : Agent , output : str ) -> GuardrailFunctionOutput :
261240 """
262241 An output guardrail that prevents mentioning pizza.
263242 """
264243 output_text = output .lower () if isinstance (output , str ) else ""
265244 contains_pizza = "pizza" in output_text
266-
245+
267246 return GuardrailFunctionOutput (
268247 output_info = {
269248 "contains_pizza" : contains_pizza ,
270249 "rejection_message" : (
271250 "I cannot provide this response as it mentions pizza. "
272251 "Due to content policies, I need to avoid discussing pizza. "
273252 "Let me provide a different response."
274- ) if contains_pizza else None
253+ )
254+ if contains_pizza
255+ else None ,
275256 },
276- tripwire_triggered = contains_pizza
257+ tripwire_triggered = contains_pizza ,
277258 )
278259
279260
280261# Define sushi output guardrail function
281- async def check_sushi_guardrail (
282- ctx : RunContextWrapper [None ],
283- agent : Agent ,
284- output : str
285- ) -> GuardrailFunctionOutput :
262+ async def check_sushi_guardrail (ctx : RunContextWrapper [None ], agent : Agent , output : str ) -> GuardrailFunctionOutput :
286263 """
287264 An output guardrail that prevents mentioning sushi.
288265 """
289266 output_text = output .lower () if isinstance (output , str ) else ""
290267 contains_sushi = "sushi" in output_text
291-
268+
292269 return GuardrailFunctionOutput (
293270 output_info = {
294271 "contains_sushi" : contains_sushi ,
295272 "rejection_message" : (
296273 "I cannot mention sushi in my response. "
297274 "This guardrail prevents discussions about sushi for demonstration purposes. "
298275 "Please let me provide information about other topics."
299- ) if contains_sushi else None
276+ )
277+ if contains_sushi
278+ else None ,
300279 },
301- tripwire_triggered = contains_sushi
280+ tripwire_triggered = contains_sushi ,
302281 )
303282
304283
305284# Create the output guardrails
306- PIZZA_GUARDRAIL = TemporalOutputGuardrail (
307- guardrail_function = check_pizza_guardrail ,
308- name = "pizza_guardrail"
309- )
285+ PIZZA_GUARDRAIL = TemporalOutputGuardrail (guardrail_function = check_pizza_guardrail , name = "pizza_guardrail" )
310286
311- SUSHI_GUARDRAIL = TemporalOutputGuardrail (
312- guardrail_function = check_sushi_guardrail ,
313- name = "sushi_guardrail"
314- )
287+ SUSHI_GUARDRAIL = TemporalOutputGuardrail (guardrail_function = check_sushi_guardrail , name = "sushi_guardrail" )
315288
316289
317290# Example output guardrail function (kept for reference)
318291async def check_output_length_guardrail (
319- ctx : RunContextWrapper [None ],
320- agent : Agent ,
321- output : str
292+ ctx : RunContextWrapper [None ], agent : Agent , output : str
322293) -> GuardrailFunctionOutput :
323294 """
324295 A simple output guardrail that checks if the response is too long.
325296 """
326297 # Check the length of the output
327298 max_length = 1000 # Maximum allowed characters
328299 is_too_long = len (output ) > max_length if isinstance (output , str ) else False
329-
300+
330301 return GuardrailFunctionOutput (
331302 output_info = {
332303 "output_length" : len (output ) if isinstance (output , str ) else 0 ,
@@ -336,9 +307,11 @@ async def check_output_length_guardrail(
336307 f"I'm sorry, but my response is too long ({ len (output )} characters). "
337308 f"Please ask a more specific question so I can provide a concise answer "
338309 f"(max { max_length } characters)."
339- ) if is_too_long else None
310+ )
311+ if is_too_long
312+ else None ,
340313 },
341- tripwire_triggered = is_too_long
314+ tripwire_triggered = is_too_long ,
342315 )
343316
344317
@@ -353,10 +326,7 @@ async def check_output_length_guardrail(
353326# Create the calculator tool
354327CALCULATOR_TOOL = FunctionTool (
355328 name = "calculator" ,
356- description = (
357- "Performs basic arithmetic operations (add, subtract, multiply, "
358- "divide) on two numbers."
359- ),
329+ description = ("Performs basic arithmetic operations (add, subtract, multiply, divide) on two numbers." ),
360330 params_json_schema = {
361331 "type" : "object" ,
362332 "properties" : {
@@ -390,26 +360,21 @@ def __init__(self):
390360 @workflow .signal (name = SignalName .RECEIVE_EVENT )
391361 @override
392362 async def on_task_event_send (self , params : SendEventParams ) -> None :
393-
394363 if not params .event .content :
395364 return
396365 if params .event .content .type != "text" :
397366 raise ValueError (f"Expected text message, got { params .event .content .type } " )
398367
399368 if params .event .content .author != "user" :
400- raise ValueError (
401- f"Expected user message, got { params .event .content .author } "
402- )
369+ raise ValueError (f"Expected user message, got { params .event .content .author } " )
403370
404371 if self ._state is None :
405372 raise ValueError ("State is not initialized" )
406373
407374 # Increment the turn number
408375 self ._state .turn_number += 1
409376 # Add the new user message to the message history
410- self ._state .input_list .append (
411- {"role" : "user" , "content" : params .event .content .content }
412- )
377+ self ._state .input_list .append ({"role" : "user" , "content" : params .event .content .content })
413378
414379 async with adk .tracing .span (
415380 trace_id = params .task .id ,
@@ -475,7 +440,7 @@ async def on_task_event_send(self, params: SendEventParams) -> None:
475440 input_guardrails = [SPAGHETTI_GUARDRAIL , SOUP_GUARDRAIL ],
476441 output_guardrails = [PIZZA_GUARDRAIL , SUSHI_GUARDRAIL ],
477442 )
478-
443+
479444 # Update state with the final input list from result
480445 if self ._state and result :
481446 final_list = getattr (result , "final_input_list" , None )
0 commit comments