From 2e15c5c8799e6db4309fe71ad9dcdf81e084a73e Mon Sep 17 00:00:00 2001 From: hanse141 Date: Fri, 21 Nov 2025 14:07:35 -0500 Subject: [PATCH 1/4] Reimported Constructor Refactor --- src/inline/plugin.py | 591 ++++++++++++++++++++++++++++++------------- 1 file changed, 412 insertions(+), 179 deletions(-) diff --git a/src/inline/plugin.py b/src/inline/plugin.py index 11c0774..d331cf3 100644 --- a/src/inline/plugin.py +++ b/src/inline/plugin.py @@ -171,6 +171,7 @@ def __init__(self): self.tag = [] self.disabled = False self.timeout = -1.0 + self.devices = None self.globs = {} def to_test(self): @@ -293,6 +294,8 @@ class ExtractInlineTest(ast.NodeTransformer): arg_tag_str = "tag" arg_disabled_str = "disabled" arg_timeout_str = "timeout" + arg_devices_str = "devices" + diff_test_str = "diff_test" assume = "assume" inline_module_imported = False @@ -362,186 +365,44 @@ def parse_constructor(self, node): """ Parse a constructor call. """ - NUM_OF_ARGUMENTS = 6 + + # Argument Order: + # 0) test_name (str) + # 1) parameterized (bool) + # 2) repeated (positive integer) + # 3) tag (str) + # 4) disabled (bool) + # 5) timeout (positive float) + # 6) devices (str array) + + + + keyword_idxs = { + self.arg_test_name_str : 0, + self.arg_parameterized_str : 1, + self.arg_repeated_str : 2, + self.arg_tag_str : 3, + self.arg_disabled_str : 4, + self.arg_timeout_str : 5, + self.arg_devices_str : 6 + } + + NUM_OF_ARGUMENTS = 7 if len(node.args) + len(node.keywords) <= NUM_OF_ARGUMENTS: # positional arguments - if sys.version_info >= (3, 8, 0): - for index, arg in enumerate(node.args): - # check if "test_name" is a string - if index == 0 and isinstance(arg, ast.Constant) and isinstance(arg.value, str): - # get the test name if exists - self.cur_inline_test.test_name = arg.value - # check if "parameterized" is a boolean - elif index == 1 and isinstance(arg, ast.Constant) and isinstance(arg.value, bool): - self.cur_inline_test.parameterized = arg.value - # check if "repeated" is a positive integer - elif index == 2 and isinstance(arg, ast.Constant) and isinstance(arg.value, int): - if arg.value <= 0: - raise MalformedException(f"inline test: {self.arg_repeated_str} must be greater than 0") - self.cur_inline_test.repeated = arg.value - elif index == 3 and isinstance(arg.value, ast.List): - tags = [] - for elt in arg.value.elts: - if not (isinstance(elt, ast.Constant) and isinstance(elt.value, str)): - raise MalformedException(f"tag can only be List of string") - tags.append(elt.value) - self.cur_inline_test.tag = tags - elif index == 4 and isinstance(arg, ast.Constant) and isinstance(arg.value, bool): - self.cur_inline_test.disabled = arg.value - elif ( - index == 5 - and isinstance(arg, ast.Constant) - and (isinstance(arg.value, float) or isinstance(arg.value, int)) - ): - self.cur_inline_test.timeout = arg.value - else: - raise MalformedException( - f"inline test: {self.class_name_str}() accepts {NUM_OF_ARGUMENTS} arguments. 'test_name' must be a string constant, 'parameterized' must be a boolean constant, 'repeated' must be a positive integer, 'tag' must be a list of string, 'timeout' must be a positive float" - ) - # keyword arguments - for keyword in node.keywords: - # check if "test_name" is a string - if ( - keyword.arg == self.arg_test_name_str - and isinstance(keyword.value, ast.Constant) - and isinstance(keyword.value.value, str) - ): - self.cur_inline_test.test_name = keyword.value.value - # check if "parameterized" is a boolean - elif ( - keyword.arg == self.arg_parameterized_str - and isinstance(keyword.value, ast.Constant) - and isinstance(keyword.value.value, bool) - ): - self.cur_inline_test.parameterized = keyword.value.value - # check if "repeated" is a positive integer - elif ( - keyword.arg == self.arg_repeated_str - and isinstance(keyword.value, ast.Constant) - and isinstance(keyword.value.value, int) - ): - if keyword.value.value <= 0: - raise MalformedException(f"inline test: {self.arg_repeated_str} must be greater than 0") - self.cur_inline_test.repeated = keyword.value.value - # check if "tag" is a list of string - elif keyword.arg == self.arg_tag_str and isinstance(keyword.value, ast.List): - tags = [] - for elt in keyword.value.elts: - if not (isinstance(elt, ast.Constant) and isinstance(elt.value, str)): - raise MalformedException(f"tag can only be List of string") - tags.append(elt.value) - self.cur_inline_test.tag = tags - # check if "disabled" is a boolean - elif ( - keyword.arg == self.arg_disabled_str - and isinstance(keyword.value, ast.Constant) - and isinstance(keyword.value.value, bool) - ): - self.cur_inline_test.disabled = keyword.value.value - # check if "timeout" is a positive float - elif ( - keyword.arg == self.arg_timeout_str - and isinstance(keyword.value, ast.Constant) - and (isinstance(keyword.value.value, float) or isinstance(keyword.value.value, int)) - ): - if keyword.value.value <= 0.0: - raise MalformedException(f"inline test: {self.arg_timeout_str} must be greater than 0") - self.cur_inline_test.timeout = keyword.value.value - else: - raise MalformedException( - f"inline test: {self.class_name_str}() accepts {NUM_OF_ARGUMENTS} arguments. 'test_name' must be a string constant, 'parameterized' must be a boolean constant, 'repeated' must be a positive integer, 'tag' must be a list of string, 'timeout' must be a positive float" - ) - else: - for index, arg in enumerate(node.args): - # check if "test_name" is a string - if index == 0 and isinstance(arg, ast.Str) and isinstance(arg.s, str): - # get the test name if exists - self.cur_inline_test.test_name = arg.s - # check if "parameterized" is a boolean - elif index == 1 and isinstance(arg, ast.NameConstant) and isinstance(arg.value, bool): - self.cur_inline_test.parameterized = arg.value - # check if "repeated" is a positive integer - elif index == 2 and isinstance(arg, ast.Num) and isinstance(arg.n, int): - if arg.n <= 0.0: - raise MalformedException(f"inline test: {self.arg_repeated_str} must be greater than 0") - self.cur_inline_test.repeated = arg.n - # check if "tag" is a list of string - elif index == 3 and isinstance(arg.value, ast.List): - tags = [] - for elt in arg.value.elts: - if not (isinstance(elt, ast.Str) and isinstance(elt.s, str)): - raise MalformedException(f"tag can only be List of string") - tags.append(elt.s) - self.cur_inline_test.tag = tags - # check if "disabled" is a boolean - elif index == 4 and isinstance(arg, ast.NameConstant) and isinstance(arg.value, bool): - self.cur_inline_test.disabled = arg.value - # check if "timeout" is a positive int - elif ( - index == 5 and isinstance(arg, ast.Num) and (isinstance(arg.n, float) or isinstance(arg.n, int)) - ): - if arg.n <= 0.0: - raise MalformedException(f"inline test: {self.arg_timeout_str} must be greater than 0") - self.cur_inline_test.timeout = arg.n - else: - raise MalformedException( - f"inline test: {self.class_name_str}() accepts {NUM_OF_ARGUMENTS} arguments. 'test_name' must be a string constant, 'parameterized' must be a boolean constant, 'repeated' must be a positive intege, 'tag' must be a list of string, 'timeout' must be a positive float" - ) - # keyword arguments - for keyword in node.keywords: - # check if "test_name" is a string - if ( - keyword.arg == self.arg_test_name_str - and isinstance(keyword.value, ast.Str) - and isinstance(keyword.value.s, str) - ): - self.cur_inline_test.test_name = keyword.value.s - # check if "parameterized" is a boolean - elif ( - keyword.arg == self.arg_parameterized_str - and isinstance(keyword.value, ast.NameConstant) - and isinstance(keyword.value.value, bool) - ): - self.cur_inline_test.parameterized = keyword.value.value - # check if "repeated" is a positive integer - elif ( - keyword.arg == self.arg_repeated_str - and isinstance(keyword.value, ast.Num) - and isinstance(keyword.value.n, int) - ): - if keyword.value.n <= 0.0: - raise MalformedException(f"inline test: {self.arg_repeated_str} must be greater than 0") - self.cur_inline_test.repeated = keyword.value.n - # check if "tag" is a list of string - elif keyword.arg == self.arg_tag_str and isinstance(keyword.value, ast.List): - tags = [] - for elt in keyword.value.elts: - if not (isinstance(elt, ast.Str) and isinstance(elt.s, str)): - raise MalformedException(f"tag can only be List of string") - tags.append(elt.s) - self.cur_inline_test.tag = tags - # check if "disabled" is a boolean - elif ( - keyword.arg == self.arg_disabled_str - and isinstance(keyword.value, ast.NameConstant) - and isinstance(keyword.value.value, bool) - ): - self.cur_inline_test.disabled = keyword.value.value - # check if "timeout" is a positive float - elif ( - keyword.arg == self.arg_timeout_str - and isinstance(keyword.value, ast.Num) - and (isinstance(keyword.value.n, float) or isinstance(keyword.value.n, int)) - ): - if keyword.value.n <= 0.0: - raise MalformedException(f"inline test: {self.arg_timeout_str} must be greater than 0") - self.cur_inline_test.timeout = keyword.value.n - else: - raise MalformedException( - f"inline test: {self.class_name_str}() accepts {NUM_OF_ARGUMENTS} arguments. 'test_name' must be a string constant, 'parameterized' must be a boolean constant, 'repeated' must be a positive integer, 'tag' must be a list of string, 'timeout' must be a positive float" - ) - else: - raise MalformedException(f"inline test: invalid {self.class_name_str}(), expected at most 3 args") + self.parse_constructor_args(node.args) + + #keyword arguments + keyword_args = [] + + #create list with 7 null values (for each position) + for i in range(0, NUM_OF_ARGUMENTS): + keyword_args.append(None) + + for keyword in node.keywords: + keyword_args[keyword_idxs[keyword.arg]] = keyword.value + self.parse_constructor_args(keyword_args) + if not self.cur_inline_test.test_name: # by default, use lineno as test name @@ -549,6 +410,149 @@ def parse_constructor(self, node): # set the line number self.cur_inline_test.lineno = node.lineno + def parse_constructor_args(self, args): + class ConstrArgs(enum.Enum): + TEST_NAME = 0 + PARAMETERIZED = 1 + REPEATED = 2 + TAG_STR = 3 + DISABLED = 4 + TIMEOUT = 5 + DEVICES = 6 + + property_names = { + ConstrArgs.TEST_NAME : "test_name", + ConstrArgs.PARAMETERIZED : "parameterized", + ConstrArgs.REPEATED : "repeated", + ConstrArgs.TAG_STR : "tag", + ConstrArgs.DISABLED : "disabled", + ConstrArgs.TIMEOUT : "timeout", + ConstrArgs.DEVICES : "devices" + } + + pre_38_val_names = { + ConstrArgs.TEST_NAME : "s", + ConstrArgs.PARAMETERIZED : "value", + ConstrArgs.REPEATED : "n", + ConstrArgs.TAG_STR : "s", + ConstrArgs.DISABLED : "value", + ConstrArgs.TIMEOUT : "n", + ConstrArgs.DEVICES : "" + } + + pre_38_expec_ast_arg_type = { + ConstrArgs.TEST_NAME : ast.Str, + ConstrArgs.PARAMETERIZED : ast.NameConstant, + ConstrArgs.REPEATED : ast.Num, + ConstrArgs.TAG_STR : ast.List, + ConstrArgs.DISABLED : ast.NameConstant, + ConstrArgs.TIMEOUT : ast.Num, + } + + expected_ast_arg_type = { + ConstrArgs.TEST_NAME : ast.Constant, + ConstrArgs.PARAMETERIZED : ast.Constant, + ConstrArgs.REPEATED : ast.Constant, + ConstrArgs.TAG_STR : ast.List, + ConstrArgs.DISABLED : ast.Constant, + ConstrArgs.TIMEOUT : ast.Constant + } + + expected_ast_val_args = { + ConstrArgs.TEST_NAME : [str], + ConstrArgs.PARAMETERIZED : [bool], + ConstrArgs.REPEATED : [int], + ConstrArgs.TAG_STR : [None], + ConstrArgs.DISABLED : [bool], + ConstrArgs.TIMEOUT : [float, int], + ConstrArgs.DEVICES : [str] + } + + NUM_OF_ARGUMENTS = 7 + + # Arguments organized by expected ast type, value type, and index in that order + for index, arg in enumerate(args): + # Skips over null arguments; needed for keywords + if arg == None: + continue + + # Devices are not referenced in versions before 3.8; all other arguments can be from any version + if index == ConstrArgs.DEVICES and isinstance(arg, ast.List): + devices = [] + for elt in arg.elts: + if not (isinstance(elt, ast.Constant) and isinstance(elt.value, str)): + raise MalformedException("devices can only be List of string") + if elt.value not in {"cpu", "cuda", "mps"}: + raise MalformedException(f"Invalid device: {elt.value}. Must be one of ['cpu', 'cuda', 'mps']") + devices.append(elt.value) + self.cur_inline_test.devices = devices + # Assumes version is past 3.8, no explicit references to ast.Constant before 3.8 + else: + corr_arg_type = False + corr_val_type = False + value_prop_name = "" + arg_idx = ConstrArgs(index) + + if sys.version_info >= (3, 8, 0) and isinstance(arg, expected_ast_arg_type[arg_idx]): + corr_arg_type = True + value_prop_name = "value" + elif sys.version_info < (3, 8, 0) and isinstance(arg, pre_38_expec_ast_arg_type[arg_idx]): + corr_arg_type = True + value_prop_name = pre_38_val_names[arg_idx] + + # Verifies value types; skipped for ast node types with no nested values + for arg_type in expected_ast_val_args[arg_idx]: + if arg_type == None: + corr_val_type = True + break + if isinstance(arg.value, arg_type): + corr_val_type = True + break + + if corr_val_type and corr_arg_type: + # Accounts for additional checks for REPEATED and TAG_STR arguments + if arg_idx == ConstrArgs.REPEATED: + if arg.value <= 0: + raise MalformedException(f"inline test: {self.arg_repeated_str} must be greater than 0") + self.cur_inline_test.repeated = getattr(arg, value_prop_name) + elif arg_idx == ConstrArgs.TAG_STR: + tags = [] + for elt in arg.elts: + if not (isinstance(elt, ast.Constant) and isinstance(elt.value, str)): + raise MalformedException(f"tag can only be List of string") + tags.append(getattr(elt, value_prop_name)) + self.cur_inline_test.tag = tags + # For non-special cases, set the attribute defined by the dictionary + else: + setattr(self.cur_inline_test, + property_names[arg_idx], + getattr(arg, value_prop_name)) + + + ## Match implementation of above conditional tree; commented since Python < 3.10 does not support match + + # match arg_idx: + # case ConstrArgs.REPEATED: + # if arg.value <= 0: + # raise MalformedException(f"inline test: {self.arg_repeated_str} must be greater than 0") + # self.cur_inline_test.repeated = getattr(arg, value_prop_name) + # case ConstrArgs.TAG_STR: + # tags = [] + # for elt in arg.elts: + # if not (isinstance(elt, ast.Constant) and isinstance(elt.value, str)): + # raise MalformedException(f"tag can only be List of string") + # tags.append(getattr(elt, value_prop_name)) + # self.cur_inline_test.tag = tags + # # For non-special cases, set the attribute defined by the dictionary + # case _: + # setattr(self.cur_inline_test, + # property_names[arg_idx], + # getattr(arg, value_prop_name)) + else: + raise MalformedException( + f"inline test: {self.class_name_str}() accepts {NUM_OF_ARGUMENTS} arguments. 'test_name' must be a string constant, 'parameterized' must be a boolean constant, 'repeated' must be a positive integer, 'tag' must be a list of string, 'timeout' must be a positive float" + ) + def parameterized_inline_tests_init(self, node: ast.List): if not self.cur_inline_test.parameterized_inline_tests: self.cur_inline_test.parameterized_inline_tests = [InlineTest() for _ in range(len(node.elts))] @@ -885,6 +889,231 @@ def parse_check_not_same(self, node): self.cur_inline_test.check_stmts.append(assert_node) else: raise MalformedException("inline test: invalid check_not_same(), expected 2 args") + + def parse_diff_test(self, node): + if not self.cur_inline_test.devices: + raise MalformedException("diff_test can only be used with the 'devices' parameter.") + + if len(node.args) != 1: + raise MalformedException("diff_test() requires exactly 1 argument.") + + output_node = self.parse_group(node.args[0]) + + # Get the original operation + original_op = None + for stmt in self.cur_inline_test.previous_stmts: + if isinstance(stmt, ast.Assign) and stmt.targets[0].id == output_node.id: + original_op = stmt.value + break + + if not original_op: + raise MalformedException("Could not find original operation for diff_test") + + # Create our new statements + new_statements = [] + device_outputs = [] + + # Import necessary modules for seed setting - Always add these + # Import random module + import_random = ast.ImportFrom( + module='random', + names=[ast.alias(name='seed', asname=None)], + level=0 + ) + new_statements.append(import_random) + + # Import numpy.random + import_np = ast.ImportFrom( + module='numpy', + names=[ast.alias(name='random', asname='np_random')], + level=0 + ) + new_statements.append(import_np) + + # Create seed function - Always add this + seed_func_def = ast.FunctionDef( + name='set_random_seed', + args=ast.arguments( + posonlyargs=[], + args=[ast.arg(arg='seed_value', annotation=None)], + kwonlyargs=[], + kw_defaults=[], + defaults=[] + ), + body=[ + ast.Expr( + value=ast.Call( + func=ast.Name(id='seed', ctx=ast.Load()), + args=[ast.Name(id='seed_value', ctx=ast.Load())], + keywords=[] + ) + ), + ast.Expr( + value=ast.Call( + func=ast.Attribute( + value=ast.Name(id='torch', ctx=ast.Load()), + attr='manual_seed' + ), + args=[ast.Name(id='seed_value', ctx=ast.Load())], + keywords=[] + ) + ), + ast.Expr( + value=ast.Call( + func=ast.Attribute( + value=ast.Name(id='np_random', ctx=ast.Load()), + attr='seed' + ), + args=[ast.Name(id='seed_value', ctx=ast.Load())], + keywords=[] + ) + ) + ], + decorator_list=[], + returns=None + ) + new_statements.append(seed_func_def) + + # Process input tensors + for given_stmt in self.cur_inline_test.given_stmts: + input_var = given_stmt.targets[0].id + ref_var = f"{input_var}_ref" + + # Always clone inputs for in-place operations + new_statements.append( + ast.Assign( + targets=[ast.Name(id=ref_var, ctx=ast.Store())], + value=ast.Call( + func=ast.Attribute( + value=given_stmt.value, + attr="clone" + ), + args=[], + keywords=[] + ) + ) + ) + + # Create device-specific versions + for device in self.cur_inline_test.devices: + device_var = f"{input_var}_{device}" + + new_statements.append( + ast.Assign( + targets=[ast.Name(id=device_var, ctx=ast.Store())], + value=ast.Call( + func=ast.Attribute( + value=ast.Name(id=ref_var, ctx=ast.Load()), + attr="to" + ), + args=[ast.Constant(value=device)], + keywords=[] + ) + ) + ) + + # Create device-specific operations + device_input_map = {device: {} for device in self.cur_inline_test.devices} + for device in self.cur_inline_test.devices: + for given_stmt in self.cur_inline_test.given_stmts: + input_var = given_stmt.targets[0].id + device_input_map[device][input_var] = f"{input_var}_{device}" + + # Always set seed before each device operation - no condition check + new_statements.append( + ast.Expr( + value=ast.Call( + func=ast.Name(id='set_random_seed', ctx=ast.Load()), + args=[ast.Constant(value=42)], # Use constant seed 42 + keywords=[] + ) + ) + ) + + device_op = copy.deepcopy(original_op) + + # Replace input references + class ReplaceInputs(ast.NodeTransformer): + def visit_Name(self, node): + if node.id in device_input_map[device]: + return ast.Name(id=device_input_map[device][node.id], ctx=node.ctx) + return node + + device_op = ReplaceInputs().visit(device_op) + device_output = f"output_{device}" + + new_statements.append( + ast.Assign( + targets=[ast.Name(id=device_output, ctx=ast.Store())], + value=device_op + ) + ) + device_outputs.append(device_output) + + # Standard comparison method for all operations - no condition check + comparisons = [] + for i in range(len(device_outputs) - 1): + dev1 = device_outputs[i] + dev2 = device_outputs[i + 1] + + dev1_cpu = f"{dev1}_cpu" + dev2_cpu = f"{dev2}_cpu" + + # Move outputs back to CPU for comparison + new_statements.append( + ast.Assign( + targets=[ast.Name(id=dev1_cpu, ctx=ast.Store())], + value=ast.Call( + func=ast.Attribute( + value=ast.Name(id=dev1, ctx=ast.Load()), + attr="to" + ), + args=[ast.Constant(value="cpu")], + keywords=[] + ) + ) + ) + + new_statements.append( + ast.Assign( + targets=[ast.Name(id=dev2_cpu, ctx=ast.Store())], + value=ast.Call( + func=ast.Attribute( + value=ast.Name(id=dev2, ctx=ast.Load()), + attr="to" + ), + args=[ast.Constant(value="cpu")], + keywords=[] + ) + ) + ) + + # Standard allclose comparison + comparison = self.build_assert_eq( + ast.Call( + func=ast.Attribute( + value=ast.Name(id=dev1_cpu, ctx=ast.Load()), + attr="allclose" + ), + args=[ + ast.Name(id=dev2_cpu, ctx=ast.Load()) + ], + keywords=[ + ast.keyword(arg="rtol", value=ast.Constant(value=1e-4)), + ast.keyword(arg="atol", value=ast.Constant(value=1e-4)), + ast.keyword(arg="equal_nan", value=ast.Constant(value=True)) + ] + ), + ast.Constant(value=True) + ) + comparisons.append(comparison) + + # Replace statements + self.cur_inline_test.previous_stmts = new_statements + self.cur_inline_test.check_stmts = comparisons + + + def build_fail(self): equal_node = ast.Compare( @@ -986,11 +1215,13 @@ def parse_inline_test(self, node): self.parse_check_same(call) elif call.func.attr == self.check_not_same: self.parse_check_not_same(call) + elif call.func.attr == self.diff_test_str: + self.parse_diff_test(call) elif call.func.attr == self.fail_str: self.parse_fail(call) elif call.func.attr == self.given_str: raise MalformedException( - f"inline test: given() must be called before check_eq()/check_true()/check_false()" + f"inline test: given() must be called before check_eq()/check_true()/check_false()/diff_test()" ) else: raise MalformedException(f"inline test: invalid function call {self.node_to_source_code(call.func)}") @@ -1131,6 +1362,8 @@ def _find(self, tests, obj, module, globs, seen): ###################################################################### class InlineTestRunner: def run(self, test: InlineTest, out: List) -> None: + test_str = test.to_test() + print(test_str) tree = ast.parse(test.to_test()) codeobj = compile(tree, filename="", mode="exec") start_time = time.time() From 44bedd8658dc41a432190ea271775a81ba8ac04e Mon Sep 17 00:00:00 2001 From: hanse141 Date: Fri, 21 Nov 2025 15:21:39 -0500 Subject: [PATCH 2/4] Fixed Edge Cases for < 3.7 Repeated and Tag Str Parameters --- src/inline/plugin.py | 18 +++++++++++++----- tests/test_plugin.py | 4 ++++ 2 files changed, 17 insertions(+), 5 deletions(-) diff --git a/src/inline/plugin.py b/src/inline/plugin.py index d331cf3..a153ca2 100644 --- a/src/inline/plugin.py +++ b/src/inline/plugin.py @@ -505,22 +505,30 @@ class ConstrArgs(enum.Enum): if arg_type == None: corr_val_type = True break - if isinstance(arg.value, arg_type): + if isinstance(getattr(arg, value_prop_name), arg_type): corr_val_type = True break if corr_val_type and corr_arg_type: # Accounts for additional checks for REPEATED and TAG_STR arguments if arg_idx == ConstrArgs.REPEATED: - if arg.value <= 0: + value = getattr(arg, value_prop_name) + if value <= 0: raise MalformedException(f"inline test: {self.arg_repeated_str} must be greater than 0") - self.cur_inline_test.repeated = getattr(arg, value_prop_name) + self.cur_inline_test.repeated = value elif arg_idx == ConstrArgs.TAG_STR: tags = [] + + if sys.version_info < (3, 8, 0): + elt_type = ast.Str + else: + elt_type = ast.Constant + for elt in arg.elts: - if not (isinstance(elt, ast.Constant) and isinstance(elt.value, str)): + value = getattr(elt, value_prop_name) + if (not isinstance(elt, elt_type) and isinstance(value, str)): raise MalformedException(f"tag can only be List of string") - tags.append(getattr(elt, value_prop_name)) + tags.append(value) self.cur_inline_test.tag = tags # For non-special cases, set the attribute defined by the dictionary else: diff --git a/tests/test_plugin.py b/tests/test_plugin.py index 40c3096..35bf90e 100644 --- a/tests/test_plugin.py +++ b/tests/test_plugin.py @@ -2,6 +2,10 @@ from _pytest.pytester import Pytester import pytest +# For testing in Spyder only +if __name__ == "__main__": + pytest.main(['-v', '-s']) + # pytest -p pytester class TestInlinetests: From 67f0a04652aeabe25546e4b95fa363c67d62d299 Mon Sep 17 00:00:00 2001 From: hanse141 Date: Fri, 21 Nov 2025 15:22:10 -0500 Subject: [PATCH 3/4] Commented Testing Call Again --- tests/test_plugin.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_plugin.py b/tests/test_plugin.py index 35bf90e..56dd482 100644 --- a/tests/test_plugin.py +++ b/tests/test_plugin.py @@ -3,8 +3,8 @@ import pytest # For testing in Spyder only -if __name__ == "__main__": - pytest.main(['-v', '-s']) +# if __name__ == "__main__": +# pytest.main(['-v', '-s']) # pytest -p pytester From f91e6847b4e399ac84cefaaa04ad9bd4eec0975e Mon Sep 17 00:00:00 2001 From: hanse141 Date: Sun, 23 Nov 2025 12:36:06 -0500 Subject: [PATCH 4/4] Addressed Changes from Pull Request - Removed IDE Debug statements - Removed Diff_test functionality; will be added back later in future PR --- src/inline/plugin.py | 254 +------------------------------------------ tests/test_plugin.py | 4 - 2 files changed, 5 insertions(+), 253 deletions(-) diff --git a/src/inline/plugin.py b/src/inline/plugin.py index a153ca2..f8ddfc1 100644 --- a/src/inline/plugin.py +++ b/src/inline/plugin.py @@ -294,8 +294,7 @@ class ExtractInlineTest(ast.NodeTransformer): arg_tag_str = "tag" arg_disabled_str = "disabled" arg_timeout_str = "timeout" - arg_devices_str = "devices" - diff_test_str = "diff_test" + assume = "assume" inline_module_imported = False @@ -384,10 +383,9 @@ def parse_constructor(self, node): self.arg_tag_str : 3, self.arg_disabled_str : 4, self.arg_timeout_str : 5, - self.arg_devices_str : 6 } - NUM_OF_ARGUMENTS = 7 + NUM_OF_ARGUMENTS = 6 if len(node.args) + len(node.keywords) <= NUM_OF_ARGUMENTS: # positional arguments self.parse_constructor_args(node.args) @@ -418,7 +416,6 @@ class ConstrArgs(enum.Enum): TAG_STR = 3 DISABLED = 4 TIMEOUT = 5 - DEVICES = 6 property_names = { ConstrArgs.TEST_NAME : "test_name", @@ -427,7 +424,6 @@ class ConstrArgs(enum.Enum): ConstrArgs.TAG_STR : "tag", ConstrArgs.DISABLED : "disabled", ConstrArgs.TIMEOUT : "timeout", - ConstrArgs.DEVICES : "devices" } pre_38_val_names = { @@ -437,7 +433,6 @@ class ConstrArgs(enum.Enum): ConstrArgs.TAG_STR : "s", ConstrArgs.DISABLED : "value", ConstrArgs.TIMEOUT : "n", - ConstrArgs.DEVICES : "" } pre_38_expec_ast_arg_type = { @@ -465,10 +460,9 @@ class ConstrArgs(enum.Enum): ConstrArgs.TAG_STR : [None], ConstrArgs.DISABLED : [bool], ConstrArgs.TIMEOUT : [float, int], - ConstrArgs.DEVICES : [str] } - NUM_OF_ARGUMENTS = 7 + NUM_OF_ARGUMENTS = 6 # Arguments organized by expected ast type, value type, and index in that order for index, arg in enumerate(args): @@ -476,16 +470,7 @@ class ConstrArgs(enum.Enum): if arg == None: continue - # Devices are not referenced in versions before 3.8; all other arguments can be from any version - if index == ConstrArgs.DEVICES and isinstance(arg, ast.List): - devices = [] - for elt in arg.elts: - if not (isinstance(elt, ast.Constant) and isinstance(elt.value, str)): - raise MalformedException("devices can only be List of string") - if elt.value not in {"cpu", "cuda", "mps"}: - raise MalformedException(f"Invalid device: {elt.value}. Must be one of ['cpu', 'cuda', 'mps']") - devices.append(elt.value) - self.cur_inline_test.devices = devices + # Assumes version is past 3.8, no explicit references to ast.Constant before 3.8 else: corr_arg_type = False @@ -898,231 +883,6 @@ def parse_check_not_same(self, node): else: raise MalformedException("inline test: invalid check_not_same(), expected 2 args") - def parse_diff_test(self, node): - if not self.cur_inline_test.devices: - raise MalformedException("diff_test can only be used with the 'devices' parameter.") - - if len(node.args) != 1: - raise MalformedException("diff_test() requires exactly 1 argument.") - - output_node = self.parse_group(node.args[0]) - - # Get the original operation - original_op = None - for stmt in self.cur_inline_test.previous_stmts: - if isinstance(stmt, ast.Assign) and stmt.targets[0].id == output_node.id: - original_op = stmt.value - break - - if not original_op: - raise MalformedException("Could not find original operation for diff_test") - - # Create our new statements - new_statements = [] - device_outputs = [] - - # Import necessary modules for seed setting - Always add these - # Import random module - import_random = ast.ImportFrom( - module='random', - names=[ast.alias(name='seed', asname=None)], - level=0 - ) - new_statements.append(import_random) - - # Import numpy.random - import_np = ast.ImportFrom( - module='numpy', - names=[ast.alias(name='random', asname='np_random')], - level=0 - ) - new_statements.append(import_np) - - # Create seed function - Always add this - seed_func_def = ast.FunctionDef( - name='set_random_seed', - args=ast.arguments( - posonlyargs=[], - args=[ast.arg(arg='seed_value', annotation=None)], - kwonlyargs=[], - kw_defaults=[], - defaults=[] - ), - body=[ - ast.Expr( - value=ast.Call( - func=ast.Name(id='seed', ctx=ast.Load()), - args=[ast.Name(id='seed_value', ctx=ast.Load())], - keywords=[] - ) - ), - ast.Expr( - value=ast.Call( - func=ast.Attribute( - value=ast.Name(id='torch', ctx=ast.Load()), - attr='manual_seed' - ), - args=[ast.Name(id='seed_value', ctx=ast.Load())], - keywords=[] - ) - ), - ast.Expr( - value=ast.Call( - func=ast.Attribute( - value=ast.Name(id='np_random', ctx=ast.Load()), - attr='seed' - ), - args=[ast.Name(id='seed_value', ctx=ast.Load())], - keywords=[] - ) - ) - ], - decorator_list=[], - returns=None - ) - new_statements.append(seed_func_def) - - # Process input tensors - for given_stmt in self.cur_inline_test.given_stmts: - input_var = given_stmt.targets[0].id - ref_var = f"{input_var}_ref" - - # Always clone inputs for in-place operations - new_statements.append( - ast.Assign( - targets=[ast.Name(id=ref_var, ctx=ast.Store())], - value=ast.Call( - func=ast.Attribute( - value=given_stmt.value, - attr="clone" - ), - args=[], - keywords=[] - ) - ) - ) - - # Create device-specific versions - for device in self.cur_inline_test.devices: - device_var = f"{input_var}_{device}" - - new_statements.append( - ast.Assign( - targets=[ast.Name(id=device_var, ctx=ast.Store())], - value=ast.Call( - func=ast.Attribute( - value=ast.Name(id=ref_var, ctx=ast.Load()), - attr="to" - ), - args=[ast.Constant(value=device)], - keywords=[] - ) - ) - ) - - # Create device-specific operations - device_input_map = {device: {} for device in self.cur_inline_test.devices} - for device in self.cur_inline_test.devices: - for given_stmt in self.cur_inline_test.given_stmts: - input_var = given_stmt.targets[0].id - device_input_map[device][input_var] = f"{input_var}_{device}" - - # Always set seed before each device operation - no condition check - new_statements.append( - ast.Expr( - value=ast.Call( - func=ast.Name(id='set_random_seed', ctx=ast.Load()), - args=[ast.Constant(value=42)], # Use constant seed 42 - keywords=[] - ) - ) - ) - - device_op = copy.deepcopy(original_op) - - # Replace input references - class ReplaceInputs(ast.NodeTransformer): - def visit_Name(self, node): - if node.id in device_input_map[device]: - return ast.Name(id=device_input_map[device][node.id], ctx=node.ctx) - return node - - device_op = ReplaceInputs().visit(device_op) - device_output = f"output_{device}" - - new_statements.append( - ast.Assign( - targets=[ast.Name(id=device_output, ctx=ast.Store())], - value=device_op - ) - ) - device_outputs.append(device_output) - - # Standard comparison method for all operations - no condition check - comparisons = [] - for i in range(len(device_outputs) - 1): - dev1 = device_outputs[i] - dev2 = device_outputs[i + 1] - - dev1_cpu = f"{dev1}_cpu" - dev2_cpu = f"{dev2}_cpu" - - # Move outputs back to CPU for comparison - new_statements.append( - ast.Assign( - targets=[ast.Name(id=dev1_cpu, ctx=ast.Store())], - value=ast.Call( - func=ast.Attribute( - value=ast.Name(id=dev1, ctx=ast.Load()), - attr="to" - ), - args=[ast.Constant(value="cpu")], - keywords=[] - ) - ) - ) - - new_statements.append( - ast.Assign( - targets=[ast.Name(id=dev2_cpu, ctx=ast.Store())], - value=ast.Call( - func=ast.Attribute( - value=ast.Name(id=dev2, ctx=ast.Load()), - attr="to" - ), - args=[ast.Constant(value="cpu")], - keywords=[] - ) - ) - ) - - # Standard allclose comparison - comparison = self.build_assert_eq( - ast.Call( - func=ast.Attribute( - value=ast.Name(id=dev1_cpu, ctx=ast.Load()), - attr="allclose" - ), - args=[ - ast.Name(id=dev2_cpu, ctx=ast.Load()) - ], - keywords=[ - ast.keyword(arg="rtol", value=ast.Constant(value=1e-4)), - ast.keyword(arg="atol", value=ast.Constant(value=1e-4)), - ast.keyword(arg="equal_nan", value=ast.Constant(value=True)) - ] - ), - ast.Constant(value=True) - ) - comparisons.append(comparison) - - # Replace statements - self.cur_inline_test.previous_stmts = new_statements - self.cur_inline_test.check_stmts = comparisons - - - - def build_fail(self): equal_node = ast.Compare( left=ast.Constant(0), @@ -1223,13 +983,11 @@ def parse_inline_test(self, node): self.parse_check_same(call) elif call.func.attr == self.check_not_same: self.parse_check_not_same(call) - elif call.func.attr == self.diff_test_str: - self.parse_diff_test(call) elif call.func.attr == self.fail_str: self.parse_fail(call) elif call.func.attr == self.given_str: raise MalformedException( - f"inline test: given() must be called before check_eq()/check_true()/check_false()/diff_test()" + f"inline test: given() must be called before check_eq()/check_true()/check_false()" ) else: raise MalformedException(f"inline test: invalid function call {self.node_to_source_code(call.func)}") @@ -1370,8 +1128,6 @@ def _find(self, tests, obj, module, globs, seen): ###################################################################### class InlineTestRunner: def run(self, test: InlineTest, out: List) -> None: - test_str = test.to_test() - print(test_str) tree = ast.parse(test.to_test()) codeobj = compile(tree, filename="", mode="exec") start_time = time.time() diff --git a/tests/test_plugin.py b/tests/test_plugin.py index 56dd482..40c3096 100644 --- a/tests/test_plugin.py +++ b/tests/test_plugin.py @@ -2,10 +2,6 @@ from _pytest.pytester import Pytester import pytest -# For testing in Spyder only -# if __name__ == "__main__": -# pytest.main(['-v', '-s']) - # pytest -p pytester class TestInlinetests: