From d6802e06ec9771b0da0ebc126c67b345c2d4a6d4 Mon Sep 17 00:00:00 2001 From: hanse141 Date: Fri, 21 Nov 2025 14:18:18 -0500 Subject: [PATCH 1/8] Added Import Changes --- src/inline/plugin.py | 70 ++++++++++++++++++++++++++++++++++++++------ tests/test_plugin.py | 55 ++++++++++++++++++++++++---------- 2 files changed, 101 insertions(+), 24 deletions(-) diff --git a/src/inline/plugin.py b/src/inline/plugin.py index f8ddfc1..d0a8436 100644 --- a/src/inline/plugin.py +++ b/src/inline/plugin.py @@ -296,11 +296,17 @@ class ExtractInlineTest(ast.NodeTransformer): arg_timeout_str = "timeout" assume = "assume" + inline_module_imported = False + import_str = "import" + from_str = "from" + as_str = "as" + inline_module_imported = False def __init__(self): self.cur_inline_test = InlineTest() self.inline_test_list = [] + self.import_list = [] def is_inline_test_class(self, node): if isinstance(node, ast.Call): @@ -350,16 +356,39 @@ def find_previous_stmt(self, node): return prev_stmt_node return self.find_condition_stmt(prev_stmt_node) - def collect_inline_test_calls(self, node, inline_test_calls: List[ast.Call]): + def collect_inline_test_calls(self, node, inline_test_calls: List[ast.Call], import_calls: List[ast.Import], import_from_calls: List[ast.ImportFrom]): """ collect all function calls in the node """ if isinstance(node, ast.Attribute): - self.collect_inline_test_calls(node.value, inline_test_calls) + self.collect_inline_test_calls(node.value, inline_test_calls, import_calls, import_from_calls) elif isinstance(node, ast.Call): inline_test_calls.append(node) - self.collect_inline_test_calls(node.func, inline_test_calls) + self.collect_inline_test_calls(node.func, inline_test_calls, import_calls, import_from_calls) + elif isinstance(node, ast.Import): + import_calls.append(node) + self.collect_inline_test_calls(node.func, inline_test_calls, import_calls, import_from_calls) + elif isinstance(node, ast.ImportFrom): + import_from_calls.append(node) + self.collect_inline_test_calls(node.func, inline_test_calls, import_calls, import_from_calls) + + def collect_import_calls(self, node, import_calls: List[ast.Import], import_from_calls: List[ast.ImportFrom]): + """ + collect all import calls in the node (should be done first) + """ + while not isinstance(node, ast.Module) and node.parent != None: + node = node.parent + + if not isinstance(node, ast.Module): + return + + for child in node.children: + if isinstance(child, ast.Import): + import_calls.append(child) + elif isinstance(child, ast.ImportFrom): + import_from_calls.append(child) + def parse_constructor(self, node): """ Parse a constructor call. @@ -931,8 +960,13 @@ def parse_parameterized_test(self): parameterized_test.test_name = self.cur_inline_test.test_name + "_" + str(index) def parse_inline_test(self, node): - inline_test_calls = [] - self.collect_inline_test_calls(node, inline_test_calls) + import_calls = [] + import_from_calls = [] + inline_test_calls = [] + + self.collect_inline_test_calls(node, inline_test_calls, import_calls, import_from_calls) + self.collect_import_calls(node, import_calls, import_from_calls) + inline_test_calls.reverse() if len(inline_test_calls) <= 1: @@ -953,14 +987,32 @@ def parse_inline_test(self, node): self.parse_assume(call) inline_test_call_index += 1 - # "given(a, 1)" for call in inline_test_calls[inline_test_call_index:]: - if isinstance(call.func, ast.Attribute) and call.func.attr == self.given_str: - self.parse_given(call) - inline_test_call_index += 1 + if isinstance(call.func, ast.Attribute): + if call.func.attr == self.given_str: + self.parse_given(call) + inline_test_call_index += 1 + elif call.func.attr == self.diff_given_str: + self.parse_diff_given(call) + inline_test_call_index += 1 + + # match call.func.attr: + # # "given(a, 1)" + # case self.given_str: + # self.parse_given(call) + # inline_test_call_index += 1 + # # "diff_given(devices, ["cpu", "cuda"])" + # case self.diff_given_str: + # self.parse_diff_given(call) + # inline_test_call_index += 1 else: break + for import_stmt in import_calls: + self.cur_inline_test.import_stmts.append(import_stmt) + for import_stmt in import_from_calls: + self.cur_inline_test.import_stmts.append(import_stmt) + # "check_eq" or "check_true" or "check_false" or "check_neq" for call in inline_test_calls[inline_test_call_index:]: # "check_eq(a, 1)" diff --git a/tests/test_plugin.py b/tests/test_plugin.py index 40c3096..0b0e7c3 100644 --- a/tests/test_plugin.py +++ b/tests/test_plugin.py @@ -2,9 +2,48 @@ from _pytest.pytester import Pytester import pytest +# For testing in Spyder only +if __name__ == "__main__": + pytest.main(['-v', '-s']) + # pytest -p pytester class TestInlinetests: + def test_inline_detects_imports(self, pytester: Pytester): + checkfile = pytester.makepyfile( + """ + from inline import itest + import datetime + + def m(a): + b = a + datetime.timedelta(days=365) + itest().given(a, datetime.timedelta(days=1)).check_eq(b, datetime.timedelta(days=366)) + """ + ) + for x in (pytester.path, checkfile): + items, reprec = pytester.inline_genitems(x) + assert len(items) == 1 + res = pytester.runpytest() + assert res.ret != 1 + + # def test_inline_detects_from_imports(self, pytester: Pytester): + # checkfile = pytester.makepyfile( + # """ + # from inline import itest + # import numpy as np + # from scipy import stats as st + + # def m(n, p): + # b = st.binom(n, p) + # itest().given(n, 100).given(p, 0.5).check_eq(b.mean(), n * p) + # """ + # ) + # for x in (pytester.path, checkfile): + # items, reprec = pytester.inline_genitems(x) + # assert len(items) == 1 + # res = pytester.runpytest() + # assert res.ret == 0 + def test_inline_parser(self, pytester: Pytester): checkfile = pytester.makepyfile( """ @@ -31,6 +70,7 @@ def m(a): items, reprec = pytester.inline_genitems(x) assert len(items) == 0 + def test_inline_malformed_given(self, pytester: Pytester): checkfile = pytester.makepyfile( """ @@ -118,21 +158,6 @@ def m(a): res = pytester.runpytest() assert res.ret == 0 - def test_check_eq_parameterized_tests(self, pytester: Pytester): - checkfile = pytester.makepyfile( - """ - from inline import itest - def m(a): - a = a + 1 - itest(parameterized=True).given(a, [2, 3]).check_eq(a, [3, 4]) - """ - ) - for x in (pytester.path, checkfile): - items, reprec = pytester.inline_genitems(x) - assert len(items) == 2 - res = pytester.runpytest() - assert res.ret == 0 - def test_malformed_check_eq_parameterized_tests(self, pytester: Pytester): checkfile = pytester.makepyfile( """ From 0e086746c33ad6995520630b1cff54fe2b4bf89e Mon Sep 17 00:00:00 2001 From: hanse141 Date: Fri, 21 Nov 2025 14:22:19 -0500 Subject: [PATCH 2/8] Readded Additional Functionality for Imports --- src/inline/plugin.py | 371 +++++++++++++++++++++++++++++++++++++------ 1 file changed, 320 insertions(+), 51 deletions(-) diff --git a/src/inline/plugin.py b/src/inline/plugin.py index d0a8436..380ca20 100644 --- a/src/inline/plugin.py +++ b/src/inline/plugin.py @@ -159,6 +159,7 @@ def __init__(self): self.check_stmts = [] self.given_stmts = [] self.previous_stmts = [] + self.import_stmts = [] self.prev_stmt_type = PrevStmtType.StmtExpr # the line number of test statement self.lineno = 0 @@ -174,11 +175,23 @@ def __init__(self): self.devices = None self.globs = {} + def write_imports(self): + import_str = "" + for n in self.import_stmts: + import_str += ExtractInlineTest.node_to_source_code(n) + "\n" + return import_str + def to_test(self): + prefix = "\n" + + # for n in self.import_stmts: + # import_str += ExtractInlineTest.node_to_source_code(n) + "\n" + + if self.prev_stmt_type == PrevStmtType.CondExpr: if self.assume_stmts == []: - return "\n".join( - [ExtractInlineTest.node_to_source_code(n) for n in self.given_stmts] + return prefix.join( + + [ExtractInlineTest.node_to_source_code(n) for n in self.given_stmts] + [ExtractInlineTest.node_to_source_code(n) for n in self.check_stmts] ) else: @@ -187,11 +200,11 @@ def to_test(self): ) assume_statement = self.assume_stmts[0] assume_node = self.build_assume_node(assume_statement, body_nodes) - return "\n".join(ExtractInlineTest.node_to_source_code(assume_node)) + return prefix.join(ExtractInlineTest.node_to_source_code(assume_node)) else: if self.assume_stmts is None or self.assume_stmts == []: - return "\n".join( + return prefix.join( [ExtractInlineTest.node_to_source_code(n) for n in self.given_stmts] + [ExtractInlineTest.node_to_source_code(n) for n in self.previous_stmts] + [ExtractInlineTest.node_to_source_code(n) for n in self.check_stmts] @@ -202,7 +215,7 @@ def to_test(self): ) assume_statement = self.assume_stmts[0] assume_node = self.build_assume_node(assume_statement, body_nodes) - return "\n".join([ExtractInlineTest.node_to_source_code(assume_node)]) + return prefix.join([ExtractInlineTest.node_to_source_code(assume_node)]) def build_assume_node(self, assumption_node, body_nodes): return ast.If(assumption_node, body_nodes, []) @@ -252,7 +265,7 @@ class TimeoutException(Exception): ## InlineTest Parser ###################################################################### class InlinetestParser: - def parse(self, obj, globs: None): + def parse(self, obj, globs: None): # obj = open(self.file_path, "r").read(): if isinstance(obj, ModuleType): tree = ast.parse(open(obj.__file__, "r").read()) @@ -294,9 +307,10 @@ class ExtractInlineTest(ast.NodeTransformer): arg_tag_str = "tag" arg_disabled_str = "disabled" arg_timeout_str = "timeout" - + arg_devices_str = "devices" + diff_test_str = "diff_test" assume = "assume" - inline_module_imported = False + import_str = "import" from_str = "from" as_str = "as" @@ -388,7 +402,7 @@ def collect_import_calls(self, node, import_calls: List[ast.Import], import_from import_calls.append(child) elif isinstance(child, ast.ImportFrom): import_from_calls.append(child) - + def parse_constructor(self, node): """ Parse a constructor call. @@ -412,9 +426,10 @@ def parse_constructor(self, node): self.arg_tag_str : 3, self.arg_disabled_str : 4, self.arg_timeout_str : 5, + self.arg_devices_str : 6 } - NUM_OF_ARGUMENTS = 6 + NUM_OF_ARGUMENTS = 7 if len(node.args) + len(node.keywords) <= NUM_OF_ARGUMENTS: # positional arguments self.parse_constructor_args(node.args) @@ -445,6 +460,7 @@ class ConstrArgs(enum.Enum): TAG_STR = 3 DISABLED = 4 TIMEOUT = 5 + DEVICES = 6 property_names = { ConstrArgs.TEST_NAME : "test_name", @@ -453,6 +469,7 @@ class ConstrArgs(enum.Enum): ConstrArgs.TAG_STR : "tag", ConstrArgs.DISABLED : "disabled", ConstrArgs.TIMEOUT : "timeout", + ConstrArgs.DEVICES : "devices" } pre_38_val_names = { @@ -462,6 +479,7 @@ class ConstrArgs(enum.Enum): ConstrArgs.TAG_STR : "s", ConstrArgs.DISABLED : "value", ConstrArgs.TIMEOUT : "n", + ConstrArgs.DEVICES : "" } pre_38_expec_ast_arg_type = { @@ -489,9 +507,10 @@ class ConstrArgs(enum.Enum): ConstrArgs.TAG_STR : [None], ConstrArgs.DISABLED : [bool], ConstrArgs.TIMEOUT : [float, int], + ConstrArgs.DEVICES : [str] } - NUM_OF_ARGUMENTS = 6 + NUM_OF_ARGUMENTS = 7 # Arguments organized by expected ast type, value type, and index in that order for index, arg in enumerate(args): @@ -499,7 +518,16 @@ class ConstrArgs(enum.Enum): if arg == None: continue - + # Devices are not referenced in versions before 3.8; all other arguments can be from any version + if index == ConstrArgs.DEVICES and isinstance(arg, ast.List): + devices = [] + for elt in arg.elts: + if not (isinstance(elt, ast.Constant) and isinstance(elt.value, str)): + raise MalformedException("devices can only be List of string") + if elt.value not in {"cpu", "cuda", "mps"}: + raise MalformedException(f"Invalid device: {elt.value}. Must be one of ['cpu', 'cuda', 'mps']") + devices.append(elt.value) + self.cur_inline_test.devices = devices # Assumes version is past 3.8, no explicit references to ast.Constant before 3.8 else: corr_arg_type = False @@ -519,30 +547,22 @@ class ConstrArgs(enum.Enum): if arg_type == None: corr_val_type = True break - if isinstance(getattr(arg, value_prop_name), arg_type): + if isinstance(arg.value, arg_type): corr_val_type = True break if corr_val_type and corr_arg_type: # Accounts for additional checks for REPEATED and TAG_STR arguments if arg_idx == ConstrArgs.REPEATED: - value = getattr(arg, value_prop_name) - if value <= 0: + if arg.value <= 0: raise MalformedException(f"inline test: {self.arg_repeated_str} must be greater than 0") - self.cur_inline_test.repeated = value + self.cur_inline_test.repeated = getattr(arg, value_prop_name) elif arg_idx == ConstrArgs.TAG_STR: tags = [] - - if sys.version_info < (3, 8, 0): - elt_type = ast.Str - else: - elt_type = ast.Constant - for elt in arg.elts: - value = getattr(elt, value_prop_name) - if (not isinstance(elt, elt_type) and isinstance(value, str)): + if not (isinstance(elt, ast.Constant) and isinstance(elt.value, str)): raise MalformedException(f"tag can only be List of string") - tags.append(value) + tags.append(getattr(elt, value_prop_name)) self.cur_inline_test.tag = tags # For non-special cases, set the attribute defined by the dictionary else: @@ -551,7 +571,6 @@ class ConstrArgs(enum.Enum): getattr(arg, value_prop_name)) - ## Match implementation of above conditional tree; commented since Python < 3.10 does not support match # match arg_idx: # case ConstrArgs.REPEATED: @@ -574,7 +593,9 @@ class ConstrArgs(enum.Enum): raise MalformedException( f"inline test: {self.class_name_str}() accepts {NUM_OF_ARGUMENTS} arguments. 'test_name' must be a string constant, 'parameterized' must be a boolean constant, 'repeated' must be a positive integer, 'tag' must be a list of string, 'timeout' must be a positive float" ) - + #raise MalformedException("Argument " + str(index) + " incorrectly formatted. Argument should be a " + ConstrArgs.expected_ast_val_args[index].type()) + + def parameterized_inline_tests_init(self, node: ast.List): if not self.cur_inline_test.parameterized_inline_tests: self.cur_inline_test.parameterized_inline_tests = [InlineTest() for _ in range(len(node.elts))] @@ -912,6 +933,240 @@ def parse_check_not_same(self, node): else: raise MalformedException("inline test: invalid check_not_same(), expected 2 args") + def parse_diff_test(self, node): + if not self.cur_inline_test.devices: + raise MalformedException("diff_test can only be used with the 'devices' parameter.") + + if len(node.args) != 1: + raise MalformedException("diff_test() requires exactly 1 argument.") + + output_node = self.parse_group(node.args[0]) + + # Get the original operation + original_op = None + for stmt in self.cur_inline_test.previous_stmts: + if isinstance(stmt, ast.Assign) and stmt.targets[0].id == output_node.id: + original_op = stmt.value + break + + if not original_op: + raise MalformedException("Could not find original operation for diff_test") + + # Create our new statements + new_statements = [] + device_outputs = [] + + # Import necessary modules for seed setting - Always add these + # Import random module + import_random = ast.ImportFrom( + module='random', + names=[ast.alias(name='seed', asname=None)], + level=0 + ) + new_statements.append(import_random) + + # Import numpy.random + import_np = ast.ImportFrom( + module='numpy', + names=[ast.alias(name='random', asname='np_random')], + level=0 + ) + new_statements.append(import_np) + + # Create seed function - Always add this + seed_func_def = ast.FunctionDef( + name='set_random_seed', + args=ast.arguments( + posonlyargs=[], + args=[ast.arg(arg='seed_value', annotation=None)], + kwonlyargs=[], + kw_defaults=[], + defaults=[] + ), + body=[ + ast.Expr( + value=ast.Call( + func=ast.Name(id='seed', ctx=ast.Load()), + args=[ast.Name(id='seed_value', ctx=ast.Load())], + keywords=[] + ) + ), + ast.Expr( + value=ast.Call( + func=ast.Attribute( + value=ast.Name(id='torch', ctx=ast.Load()), + attr='manual_seed' + ), + args=[ast.Name(id='seed_value', ctx=ast.Load())], + keywords=[] + ) + ), + ast.Expr( + value=ast.Call( + func=ast.Attribute( + value=ast.Name(id='np_random', ctx=ast.Load()), + attr='seed' + ), + args=[ast.Name(id='seed_value', ctx=ast.Load())], + keywords=[] + ) + ) + ], + decorator_list=[], + returns=None + ) + new_statements.append(seed_func_def) + + # Process input tensors + for given_stmt in self.cur_inline_test.given_stmts: + input_var = given_stmt.targets[0].id + ref_var = f"{input_var}_ref" + + # Always clone inputs for in-place operations + new_statements.append( + ast.Assign( + targets=[ast.Name(id=ref_var, ctx=ast.Store())], + value=ast.Call( + func=ast.Attribute( + value=given_stmt.value, + attr="clone" + ), + args=[], + keywords=[] + ) + ) + ) + + # Create device-specific versions + for device in self.cur_inline_test.devices: + device_var = f"{input_var}_{device}" + + new_statements.append( + ast.Assign( + targets=[ast.Name(id=device_var, ctx=ast.Store())], + value=ast.Call( + func=ast.Attribute( + value=ast.Name(id=ref_var, ctx=ast.Load()), + attr="to" + ), + args=[ast.Constant(value=device)], + keywords=[] + ) + ) + ) + + # Create device-specific operations + device_input_map = {device: {} for device in self.cur_inline_test.devices} + for device in self.cur_inline_test.devices: + for given_stmt in self.cur_inline_test.given_stmts: + input_var = given_stmt.targets[0].id + device_input_map[device][input_var] = f"{input_var}_{device}" + + # Always set seed before each device operation - no condition check + new_statements.append( + ast.Expr( + value=ast.Call( + func=ast.Name(id='set_random_seed', ctx=ast.Load()), + args=[ast.Constant(value=42)], # Use constant seed 42 + keywords=[] + ) + ) + ) + + device_op = copy.deepcopy(original_op) + + # Replace input references + class ReplaceInputs(ast.NodeTransformer): + def visit_Name(self, node): + if node.id in device_input_map[device]: + return ast.Name(id=device_input_map[device][node.id], ctx=node.ctx) + return node + + device_op = ReplaceInputs().visit(device_op) + device_output = f"output_{device}" + + new_statements.append( + ast.Assign( + targets=[ast.Name(id=device_output, ctx=ast.Store())], + value=device_op + ) + ) + device_outputs.append(device_output) + + # Standard comparison method for all operations - no condition check + comparisons = [] + for i in range(len(device_outputs) - 1): + dev1 = device_outputs[i] + dev2 = device_outputs[i + 1] + + dev1_cpu = f"{dev1}_cpu" + dev2_cpu = f"{dev2}_cpu" + + # Move outputs back to CPU for comparison + new_statements.append( + ast.Assign( + targets=[ast.Name(id=dev1_cpu, ctx=ast.Store())], + value=ast.Call( + func=ast.Attribute( + value=ast.Name(id=dev1, ctx=ast.Load()), + attr="to" + ), + args=[ast.Constant(value="cpu")], + keywords=[] + ) + ) + ) + + new_statements.append( + ast.Assign( + targets=[ast.Name(id=dev2_cpu, ctx=ast.Store())], + value=ast.Call( + func=ast.Attribute( + value=ast.Name(id=dev2, ctx=ast.Load()), + attr="to" + ), + args=[ast.Constant(value="cpu")], + keywords=[] + ) + ) + ) + + # Standard allclose comparison + comparison = self.build_assert_eq( + ast.Call( + func=ast.Attribute( + value=ast.Name(id=dev1_cpu, ctx=ast.Load()), + attr="allclose" + ), + args=[ + ast.Name(id=dev2_cpu, ctx=ast.Load()) + ], + keywords=[ + ast.keyword(arg="rtol", value=ast.Constant(value=1e-4)), + ast.keyword(arg="atol", value=ast.Constant(value=1e-4)), + ast.keyword(arg="equal_nan", value=ast.Constant(value=True)) + ] + ), + ast.Constant(value=True) + ) + comparisons.append(comparison) + + # Replace statements + self.cur_inline_test.previous_stmts = new_statements + self.cur_inline_test.check_stmts = comparisons + + def parse_import(self, node): + # TODO: Differentiate between import, from import, and import alias + import_node = ast.Import( + names=[ + ast.alias(name=node) + ] + ) + return import_node + + def parse_import_from(self, node): + pass + def build_fail(self): equal_node = ast.Compare( left=ast.Constant(0), @@ -951,6 +1206,7 @@ def parse_group(self, node): return stmt else: return node + def parse_parameterized_test(self): for index, parameterized_test in enumerate(self.cur_inline_test.parameterized_inline_tests): @@ -987,24 +1243,11 @@ def parse_inline_test(self, node): self.parse_assume(call) inline_test_call_index += 1 + # "given(a, 1)" for call in inline_test_calls[inline_test_call_index:]: - if isinstance(call.func, ast.Attribute): - if call.func.attr == self.given_str: - self.parse_given(call) - inline_test_call_index += 1 - elif call.func.attr == self.diff_given_str: - self.parse_diff_given(call) - inline_test_call_index += 1 - - # match call.func.attr: - # # "given(a, 1)" - # case self.given_str: - # self.parse_given(call) - # inline_test_call_index += 1 - # # "diff_given(devices, ["cpu", "cuda"])" - # case self.diff_given_str: - # self.parse_diff_given(call) - # inline_test_call_index += 1 + if isinstance(call.func, ast.Attribute) and call.func.attr == self.given_str: + self.parse_given(call) + inline_test_call_index += 1 else: break @@ -1013,6 +1256,7 @@ def parse_inline_test(self, node): for import_stmt in import_from_calls: self.cur_inline_test.import_stmts.append(import_stmt) + # "check_eq" or "check_true" or "check_false" or "check_neq" for call in inline_test_calls[inline_test_call_index:]: # "check_eq(a, 1)" @@ -1035,11 +1279,13 @@ def parse_inline_test(self, node): self.parse_check_same(call) elif call.func.attr == self.check_not_same: self.parse_check_not_same(call) + elif call.func.attr == self.diff_test_str: + self.parse_diff_test(call) elif call.func.attr == self.fail_str: self.parse_fail(call) elif call.func.attr == self.given_str: raise MalformedException( - f"inline test: given() must be called before check_eq()/check_true()/check_false()" + f"inline test: given() must be called before check_eq()/check_true()/check_false()/diff_test()" ) else: raise MalformedException(f"inline test: invalid function call {self.node_to_source_code(call.func)}") @@ -1073,6 +1319,7 @@ def node_to_source_code(node): ## InlineTest Finder ###################################################################### class InlineTestFinder: + # Finder should NOT store any global variables def __init__(self, parser=InlinetestParser(), recurse=True, exclude_empty=True): self._parser = parser self._recurse = recurse @@ -1117,7 +1364,14 @@ def _is_routine(self, obj): pass return inspect.isroutine(maybe_routine) - def find(self, obj, module=None, globs=None, extraglobs=None): + # def find_imports(self, obj, module=None): + # if module is False: + # module = None + # elif module is None: + # module = inspect.getmodule(obj) + + + def find(self, obj, module=None, globs=None, extraglobs=None, imports=None): # Find the module that contains the given object (if obj is # a module, then module=obj.). if module is False: @@ -1138,15 +1392,23 @@ def find(self, obj, module=None, globs=None, extraglobs=None): if "__name__" not in globs: globs["__name__"] = "__main__" # provide a default module name + # Find intersection between loaded modules and module imports + # if imports is None: + # imports = set(sys.modules) & set(globs) + # else: + # imports = imports.copy() + # Recursively explore `obj`, extracting InlineTests. tests = [] - self._find(tests, obj, module, globs, {}) + self._find(tests, obj, module, globs, imports, {}) return tests - def _find(self, tests, obj, module, globs, seen): + def _find(self, tests, obj, module, globs, imports, seen): if id(obj) in seen: return seen[id(obj)] = 1 + + # Find a test for this object, and add it to the list of tests. test = self._parser.parse(obj, globs) if test is not None: @@ -1158,7 +1420,7 @@ def _find(self, tests, obj, module, globs, seen): # Recurse to functions & classes. if (self._is_routine(val) or inspect.isclass(val)) and self._from_module(module, val): - self._find(tests, val, module, globs, seen) + self._find(tests, val, module, globs, imports, seen) # Look for tests in a class's contained objects. if inspect.isclass(obj) and self._recurse: @@ -1172,7 +1434,7 @@ def _find(self, tests, obj, module, globs, seen): module, val ): valname = "%s" % (valname) - self._find(tests, val, module, globs, seen) + self._find(tests, val, module, globs, imports, seen) ###################################################################### @@ -1180,7 +1442,10 @@ def _find(self, tests, obj, module, globs, seen): ###################################################################### class InlineTestRunner: def run(self, test: InlineTest, out: List) -> None: - tree = ast.parse(test.to_test()) + test_str = test.write_imports() + test_str += test.to_test() + print(test_str) + tree = ast.parse(test_str) codeobj = compile(tree, filename="", mode="exec") start_time = time.time() if test.timeout > 0: @@ -1317,6 +1582,10 @@ def collect(self) -> Iterable[InlinetestItem]: group_tags = self.config.getoption("inlinetest_group", default=None) order_tags = self.config.getoption("inlinetest_order", default=None) + # TODO: import all modules through the finder first before extracting inline tests + # - Create ast for all imports + # - If a function references an import, then include the imported library reference in the ast + for test_list in finder.find(module): # reorder the list if there are tests to be ordered ordered_list = InlinetestModule.order_tests(test_list, order_tags) From 7bc3be59161a8655dec972e855300c82e3703b84 Mon Sep 17 00:00:00 2001 From: hanse141 Date: Fri, 21 Nov 2025 14:26:07 -0500 Subject: [PATCH 3/8] Removed Spyder Testing Line --- tests/test_plugin.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_plugin.py b/tests/test_plugin.py index 0b0e7c3..abdc0af 100644 --- a/tests/test_plugin.py +++ b/tests/test_plugin.py @@ -2,9 +2,9 @@ from _pytest.pytester import Pytester import pytest -# For testing in Spyder only -if __name__ == "__main__": - pytest.main(['-v', '-s']) +# # For testing in Spyder only +# if __name__ == "__main__": +# pytest.main(['-v', '-s']) # pytest -p pytester From 19fcf7c6a9095ea86102d65bf4595441183b8701 Mon Sep 17 00:00:00 2001 From: hanse141 Date: Sun, 23 Nov 2025 13:02:24 -0500 Subject: [PATCH 4/8] Added Import Functionality --- src/inline/plugin.py | 351 +++++-------------------------------------- 1 file changed, 36 insertions(+), 315 deletions(-) diff --git a/src/inline/plugin.py b/src/inline/plugin.py index 380ca20..f2e3189 100644 --- a/src/inline/plugin.py +++ b/src/inline/plugin.py @@ -184,14 +184,10 @@ def write_imports(self): def to_test(self): prefix = "\n" - # for n in self.import_stmts: - # import_str += ExtractInlineTest.node_to_source_code(n) + "\n" - - if self.prev_stmt_type == PrevStmtType.CondExpr: if self.assume_stmts == []: return prefix.join( - + [ExtractInlineTest.node_to_source_code(n) for n in self.given_stmts] + [ExtractInlineTest.node_to_source_code(n) for n in self.given_stmts] + [ExtractInlineTest.node_to_source_code(n) for n in self.check_stmts] ) else: @@ -265,7 +261,7 @@ class TimeoutException(Exception): ## InlineTest Parser ###################################################################### class InlinetestParser: - def parse(self, obj, globs: None): + def parse(self, obj, globs: None): # obj = open(self.file_path, "r").read(): if isinstance(obj, ModuleType): tree = ast.parse(open(obj.__file__, "r").read()) @@ -307,8 +303,7 @@ class ExtractInlineTest(ast.NodeTransformer): arg_tag_str = "tag" arg_disabled_str = "disabled" arg_timeout_str = "timeout" - arg_devices_str = "devices" - diff_test_str = "diff_test" + assume = "assume" import_str = "import" @@ -320,7 +315,6 @@ class ExtractInlineTest(ast.NodeTransformer): def __init__(self): self.cur_inline_test = InlineTest() self.inline_test_list = [] - self.import_list = [] def is_inline_test_class(self, node): if isinstance(node, ast.Call): @@ -370,21 +364,15 @@ def find_previous_stmt(self, node): return prev_stmt_node return self.find_condition_stmt(prev_stmt_node) - def collect_inline_test_calls(self, node, inline_test_calls: List[ast.Call], import_calls: List[ast.Import], import_from_calls: List[ast.ImportFrom]): + def collect_inline_test_calls(self, node, inline_test_calls: List[ast.Call]): """ collect all function calls in the node """ if isinstance(node, ast.Attribute): - self.collect_inline_test_calls(node.value, inline_test_calls, import_calls, import_from_calls) + self.collect_inline_test_calls(node.value, inline_test_calls) elif isinstance(node, ast.Call): inline_test_calls.append(node) - self.collect_inline_test_calls(node.func, inline_test_calls, import_calls, import_from_calls) - elif isinstance(node, ast.Import): - import_calls.append(node) - self.collect_inline_test_calls(node.func, inline_test_calls, import_calls, import_from_calls) - elif isinstance(node, ast.ImportFrom): - import_from_calls.append(node) - self.collect_inline_test_calls(node.func, inline_test_calls, import_calls, import_from_calls) + self.collect_inline_test_calls(node.func, inline_test_calls) def collect_import_calls(self, node, import_calls: List[ast.Import], import_from_calls: List[ast.ImportFrom]): """ @@ -426,10 +414,9 @@ def parse_constructor(self, node): self.arg_tag_str : 3, self.arg_disabled_str : 4, self.arg_timeout_str : 5, - self.arg_devices_str : 6 } - NUM_OF_ARGUMENTS = 7 + NUM_OF_ARGUMENTS = 6 if len(node.args) + len(node.keywords) <= NUM_OF_ARGUMENTS: # positional arguments self.parse_constructor_args(node.args) @@ -460,7 +447,6 @@ class ConstrArgs(enum.Enum): TAG_STR = 3 DISABLED = 4 TIMEOUT = 5 - DEVICES = 6 property_names = { ConstrArgs.TEST_NAME : "test_name", @@ -469,7 +455,6 @@ class ConstrArgs(enum.Enum): ConstrArgs.TAG_STR : "tag", ConstrArgs.DISABLED : "disabled", ConstrArgs.TIMEOUT : "timeout", - ConstrArgs.DEVICES : "devices" } pre_38_val_names = { @@ -479,7 +464,6 @@ class ConstrArgs(enum.Enum): ConstrArgs.TAG_STR : "s", ConstrArgs.DISABLED : "value", ConstrArgs.TIMEOUT : "n", - ConstrArgs.DEVICES : "" } pre_38_expec_ast_arg_type = { @@ -507,10 +491,9 @@ class ConstrArgs(enum.Enum): ConstrArgs.TAG_STR : [None], ConstrArgs.DISABLED : [bool], ConstrArgs.TIMEOUT : [float, int], - ConstrArgs.DEVICES : [str] } - NUM_OF_ARGUMENTS = 7 + NUM_OF_ARGUMENTS = 6 # Arguments organized by expected ast type, value type, and index in that order for index, arg in enumerate(args): @@ -518,16 +501,7 @@ class ConstrArgs(enum.Enum): if arg == None: continue - # Devices are not referenced in versions before 3.8; all other arguments can be from any version - if index == ConstrArgs.DEVICES and isinstance(arg, ast.List): - devices = [] - for elt in arg.elts: - if not (isinstance(elt, ast.Constant) and isinstance(elt.value, str)): - raise MalformedException("devices can only be List of string") - if elt.value not in {"cpu", "cuda", "mps"}: - raise MalformedException(f"Invalid device: {elt.value}. Must be one of ['cpu', 'cuda', 'mps']") - devices.append(elt.value) - self.cur_inline_test.devices = devices + # Assumes version is past 3.8, no explicit references to ast.Constant before 3.8 else: corr_arg_type = False @@ -547,22 +521,30 @@ class ConstrArgs(enum.Enum): if arg_type == None: corr_val_type = True break - if isinstance(arg.value, arg_type): + if isinstance(getattr(arg, value_prop_name), arg_type): corr_val_type = True break if corr_val_type and corr_arg_type: # Accounts for additional checks for REPEATED and TAG_STR arguments if arg_idx == ConstrArgs.REPEATED: - if arg.value <= 0: + value = getattr(arg, value_prop_name) + if value <= 0: raise MalformedException(f"inline test: {self.arg_repeated_str} must be greater than 0") - self.cur_inline_test.repeated = getattr(arg, value_prop_name) + self.cur_inline_test.repeated = value elif arg_idx == ConstrArgs.TAG_STR: tags = [] + + if sys.version_info < (3, 8, 0): + elt_type = ast.Str + else: + elt_type = ast.Constant + for elt in arg.elts: - if not (isinstance(elt, ast.Constant) and isinstance(elt.value, str)): + value = getattr(elt, value_prop_name) + if (not isinstance(elt, elt_type) and isinstance(value, str)): raise MalformedException(f"tag can only be List of string") - tags.append(getattr(elt, value_prop_name)) + tags.append(value) self.cur_inline_test.tag = tags # For non-special cases, set the attribute defined by the dictionary else: @@ -571,6 +553,7 @@ class ConstrArgs(enum.Enum): getattr(arg, value_prop_name)) + ## Match implementation of above conditional tree; commented since Python < 3.10 does not support match # match arg_idx: # case ConstrArgs.REPEATED: @@ -593,9 +576,7 @@ class ConstrArgs(enum.Enum): raise MalformedException( f"inline test: {self.class_name_str}() accepts {NUM_OF_ARGUMENTS} arguments. 'test_name' must be a string constant, 'parameterized' must be a boolean constant, 'repeated' must be a positive integer, 'tag' must be a list of string, 'timeout' must be a positive float" ) - #raise MalformedException("Argument " + str(index) + " incorrectly formatted. Argument should be a " + ConstrArgs.expected_ast_val_args[index].type()) - - + def parameterized_inline_tests_init(self, node: ast.List): if not self.cur_inline_test.parameterized_inline_tests: self.cur_inline_test.parameterized_inline_tests = [InlineTest() for _ in range(len(node.elts))] @@ -933,240 +914,6 @@ def parse_check_not_same(self, node): else: raise MalformedException("inline test: invalid check_not_same(), expected 2 args") - def parse_diff_test(self, node): - if not self.cur_inline_test.devices: - raise MalformedException("diff_test can only be used with the 'devices' parameter.") - - if len(node.args) != 1: - raise MalformedException("diff_test() requires exactly 1 argument.") - - output_node = self.parse_group(node.args[0]) - - # Get the original operation - original_op = None - for stmt in self.cur_inline_test.previous_stmts: - if isinstance(stmt, ast.Assign) and stmt.targets[0].id == output_node.id: - original_op = stmt.value - break - - if not original_op: - raise MalformedException("Could not find original operation for diff_test") - - # Create our new statements - new_statements = [] - device_outputs = [] - - # Import necessary modules for seed setting - Always add these - # Import random module - import_random = ast.ImportFrom( - module='random', - names=[ast.alias(name='seed', asname=None)], - level=0 - ) - new_statements.append(import_random) - - # Import numpy.random - import_np = ast.ImportFrom( - module='numpy', - names=[ast.alias(name='random', asname='np_random')], - level=0 - ) - new_statements.append(import_np) - - # Create seed function - Always add this - seed_func_def = ast.FunctionDef( - name='set_random_seed', - args=ast.arguments( - posonlyargs=[], - args=[ast.arg(arg='seed_value', annotation=None)], - kwonlyargs=[], - kw_defaults=[], - defaults=[] - ), - body=[ - ast.Expr( - value=ast.Call( - func=ast.Name(id='seed', ctx=ast.Load()), - args=[ast.Name(id='seed_value', ctx=ast.Load())], - keywords=[] - ) - ), - ast.Expr( - value=ast.Call( - func=ast.Attribute( - value=ast.Name(id='torch', ctx=ast.Load()), - attr='manual_seed' - ), - args=[ast.Name(id='seed_value', ctx=ast.Load())], - keywords=[] - ) - ), - ast.Expr( - value=ast.Call( - func=ast.Attribute( - value=ast.Name(id='np_random', ctx=ast.Load()), - attr='seed' - ), - args=[ast.Name(id='seed_value', ctx=ast.Load())], - keywords=[] - ) - ) - ], - decorator_list=[], - returns=None - ) - new_statements.append(seed_func_def) - - # Process input tensors - for given_stmt in self.cur_inline_test.given_stmts: - input_var = given_stmt.targets[0].id - ref_var = f"{input_var}_ref" - - # Always clone inputs for in-place operations - new_statements.append( - ast.Assign( - targets=[ast.Name(id=ref_var, ctx=ast.Store())], - value=ast.Call( - func=ast.Attribute( - value=given_stmt.value, - attr="clone" - ), - args=[], - keywords=[] - ) - ) - ) - - # Create device-specific versions - for device in self.cur_inline_test.devices: - device_var = f"{input_var}_{device}" - - new_statements.append( - ast.Assign( - targets=[ast.Name(id=device_var, ctx=ast.Store())], - value=ast.Call( - func=ast.Attribute( - value=ast.Name(id=ref_var, ctx=ast.Load()), - attr="to" - ), - args=[ast.Constant(value=device)], - keywords=[] - ) - ) - ) - - # Create device-specific operations - device_input_map = {device: {} for device in self.cur_inline_test.devices} - for device in self.cur_inline_test.devices: - for given_stmt in self.cur_inline_test.given_stmts: - input_var = given_stmt.targets[0].id - device_input_map[device][input_var] = f"{input_var}_{device}" - - # Always set seed before each device operation - no condition check - new_statements.append( - ast.Expr( - value=ast.Call( - func=ast.Name(id='set_random_seed', ctx=ast.Load()), - args=[ast.Constant(value=42)], # Use constant seed 42 - keywords=[] - ) - ) - ) - - device_op = copy.deepcopy(original_op) - - # Replace input references - class ReplaceInputs(ast.NodeTransformer): - def visit_Name(self, node): - if node.id in device_input_map[device]: - return ast.Name(id=device_input_map[device][node.id], ctx=node.ctx) - return node - - device_op = ReplaceInputs().visit(device_op) - device_output = f"output_{device}" - - new_statements.append( - ast.Assign( - targets=[ast.Name(id=device_output, ctx=ast.Store())], - value=device_op - ) - ) - device_outputs.append(device_output) - - # Standard comparison method for all operations - no condition check - comparisons = [] - for i in range(len(device_outputs) - 1): - dev1 = device_outputs[i] - dev2 = device_outputs[i + 1] - - dev1_cpu = f"{dev1}_cpu" - dev2_cpu = f"{dev2}_cpu" - - # Move outputs back to CPU for comparison - new_statements.append( - ast.Assign( - targets=[ast.Name(id=dev1_cpu, ctx=ast.Store())], - value=ast.Call( - func=ast.Attribute( - value=ast.Name(id=dev1, ctx=ast.Load()), - attr="to" - ), - args=[ast.Constant(value="cpu")], - keywords=[] - ) - ) - ) - - new_statements.append( - ast.Assign( - targets=[ast.Name(id=dev2_cpu, ctx=ast.Store())], - value=ast.Call( - func=ast.Attribute( - value=ast.Name(id=dev2, ctx=ast.Load()), - attr="to" - ), - args=[ast.Constant(value="cpu")], - keywords=[] - ) - ) - ) - - # Standard allclose comparison - comparison = self.build_assert_eq( - ast.Call( - func=ast.Attribute( - value=ast.Name(id=dev1_cpu, ctx=ast.Load()), - attr="allclose" - ), - args=[ - ast.Name(id=dev2_cpu, ctx=ast.Load()) - ], - keywords=[ - ast.keyword(arg="rtol", value=ast.Constant(value=1e-4)), - ast.keyword(arg="atol", value=ast.Constant(value=1e-4)), - ast.keyword(arg="equal_nan", value=ast.Constant(value=True)) - ] - ), - ast.Constant(value=True) - ) - comparisons.append(comparison) - - # Replace statements - self.cur_inline_test.previous_stmts = new_statements - self.cur_inline_test.check_stmts = comparisons - - def parse_import(self, node): - # TODO: Differentiate between import, from import, and import alias - import_node = ast.Import( - names=[ - ast.alias(name=node) - ] - ) - return import_node - - def parse_import_from(self, node): - pass - def build_fail(self): equal_node = ast.Compare( left=ast.Constant(0), @@ -1206,7 +953,6 @@ def parse_group(self, node): return stmt else: return node - def parse_parameterized_test(self): for index, parameterized_test in enumerate(self.cur_inline_test.parameterized_inline_tests): @@ -1220,7 +966,7 @@ def parse_inline_test(self, node): import_from_calls = [] inline_test_calls = [] - self.collect_inline_test_calls(node, inline_test_calls, import_calls, import_from_calls) + self.collect_inline_test_calls(node, inline_test_calls) self.collect_import_calls(node, import_calls, import_from_calls) inline_test_calls.reverse() @@ -1243,11 +989,11 @@ def parse_inline_test(self, node): self.parse_assume(call) inline_test_call_index += 1 - # "given(a, 1)" for call in inline_test_calls[inline_test_call_index:]: - if isinstance(call.func, ast.Attribute) and call.func.attr == self.given_str: - self.parse_given(call) - inline_test_call_index += 1 + if isinstance(call.func, ast.Attribute): + if call.func.attr == self.given_str: + self.parse_given(call) + inline_test_call_index += 1 else: break @@ -1279,13 +1025,11 @@ def parse_inline_test(self, node): self.parse_check_same(call) elif call.func.attr == self.check_not_same: self.parse_check_not_same(call) - elif call.func.attr == self.diff_test_str: - self.parse_diff_test(call) elif call.func.attr == self.fail_str: self.parse_fail(call) elif call.func.attr == self.given_str: raise MalformedException( - f"inline test: given() must be called before check_eq()/check_true()/check_false()/diff_test()" + f"inline test: given() must be called before check_eq()/check_true()/check_false()" ) else: raise MalformedException(f"inline test: invalid function call {self.node_to_source_code(call.func)}") @@ -1319,7 +1063,6 @@ def node_to_source_code(node): ## InlineTest Finder ###################################################################### class InlineTestFinder: - # Finder should NOT store any global variables def __init__(self, parser=InlinetestParser(), recurse=True, exclude_empty=True): self._parser = parser self._recurse = recurse @@ -1364,14 +1107,7 @@ def _is_routine(self, obj): pass return inspect.isroutine(maybe_routine) - # def find_imports(self, obj, module=None): - # if module is False: - # module = None - # elif module is None: - # module = inspect.getmodule(obj) - - - def find(self, obj, module=None, globs=None, extraglobs=None, imports=None): + def find(self, obj, module=None, globs=None, extraglobs=None): # Find the module that contains the given object (if obj is # a module, then module=obj.). if module is False: @@ -1392,23 +1128,15 @@ def find(self, obj, module=None, globs=None, extraglobs=None, imports=None): if "__name__" not in globs: globs["__name__"] = "__main__" # provide a default module name - # Find intersection between loaded modules and module imports - # if imports is None: - # imports = set(sys.modules) & set(globs) - # else: - # imports = imports.copy() - # Recursively explore `obj`, extracting InlineTests. tests = [] - self._find(tests, obj, module, globs, imports, {}) + self._find(tests, obj, module, globs, {}) return tests - def _find(self, tests, obj, module, globs, imports, seen): + def _find(self, tests, obj, module, globs, seen): if id(obj) in seen: return seen[id(obj)] = 1 - - # Find a test for this object, and add it to the list of tests. test = self._parser.parse(obj, globs) if test is not None: @@ -1420,7 +1148,7 @@ def _find(self, tests, obj, module, globs, imports, seen): # Recurse to functions & classes. if (self._is_routine(val) or inspect.isclass(val)) and self._from_module(module, val): - self._find(tests, val, module, globs, imports, seen) + self._find(tests, val, module, globs, seen) # Look for tests in a class's contained objects. if inspect.isclass(obj) and self._recurse: @@ -1434,7 +1162,7 @@ def _find(self, tests, obj, module, globs, imports, seen): module, val ): valname = "%s" % (valname) - self._find(tests, val, module, globs, imports, seen) + self._find(tests, val, module, globs, seen) ###################################################################### @@ -1442,10 +1170,7 @@ def _find(self, tests, obj, module, globs, imports, seen): ###################################################################### class InlineTestRunner: def run(self, test: InlineTest, out: List) -> None: - test_str = test.write_imports() - test_str += test.to_test() - print(test_str) - tree = ast.parse(test_str) + tree = ast.parse(test.to_test()) codeobj = compile(tree, filename="", mode="exec") start_time = time.time() if test.timeout > 0: @@ -1582,10 +1307,6 @@ def collect(self) -> Iterable[InlinetestItem]: group_tags = self.config.getoption("inlinetest_group", default=None) order_tags = self.config.getoption("inlinetest_order", default=None) - # TODO: import all modules through the finder first before extracting inline tests - # - Create ast for all imports - # - If a function references an import, then include the imported library reference in the ast - for test_list in finder.find(module): # reorder the list if there are tests to be ordered ordered_list = InlinetestModule.order_tests(test_list, order_tags) From 7fbcdca8d409b4d79ea5005a5e79ae2161e853ab Mon Sep 17 00:00:00 2001 From: hanse141 Date: Sun, 23 Nov 2025 13:44:02 -0500 Subject: [PATCH 5/8] Added Diff Given and Added Back Diff Test Functionality --- src/inline/plugin.py | 254 ++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 253 insertions(+), 1 deletion(-) diff --git a/src/inline/plugin.py b/src/inline/plugin.py index f2e3189..1288bec 100644 --- a/src/inline/plugin.py +++ b/src/inline/plugin.py @@ -296,6 +296,7 @@ class ExtractInlineTest(ast.NodeTransformer): check_not_same = "check_not_same" fail_str = "fail" given_str = "given" + diff_given_str = "diff_given" group_str = "Group" arg_test_name_str = "test_name" arg_parameterized_str = "parameterized" @@ -303,6 +304,8 @@ class ExtractInlineTest(ast.NodeTransformer): arg_tag_str = "tag" arg_disabled_str = "disabled" arg_timeout_str = "timeout" + arg_devices_str = "devices" + diff_test_str = "diff_test" assume = "assume" @@ -596,6 +599,23 @@ def parse_given(self, node): else: raise MalformedException("inline test: invalid given(), expected 2 args") + def parse_diff_given(self, node): + PROPERTY = 0 + VALUES = 1 + + if len(node.args) == 2: + if self.cur_inline_test.parameterized: + raise MalformedException("inline test: Parameterized inline tests currently do not support differential tests.") + else: + devices = [] + for elt in node.args[VALUES].elts: + if elt.value not in {"cpu", "cuda", "mps"}: + raise MalformedException(f"Invalid device: {elt.value}. Must be one of ['cpu', 'cuda', 'mps']") + devices.append(elt.value) + setattr(self.cur_inline_test, node.args[PROPERTY].id, devices) + else: + raise MalformedException("inline test: invalid diff_given(), expected 2 args") + def parse_assume(self, node): if len(node.args) == 1: if self.cur_inline_test.parameterized: @@ -930,6 +950,229 @@ def parse_fail(self, node): else: raise MalformedException("inline test: fail() does not expect any arguments") + def parse_diff_test(self, node): + if not self.cur_inline_test.devices: + raise MalformedException("diff_test can only be used with the 'devices' parameter.") + + if len(node.args) != 1: + raise MalformedException("diff_test() requires exactly 1 argument.") + + output_node = self.parse_group(node.args[0]) + + # Get the original operation + original_op = None + for stmt in self.cur_inline_test.previous_stmts: + if isinstance(stmt, ast.Assign) and stmt.targets[0].id == output_node.id: + original_op = stmt.value + break + + if not original_op: + raise MalformedException("Could not find original operation for diff_test") + + # Create our new statements + new_statements = [] + device_outputs = [] + + # Import necessary modules for seed setting - Always add these + # Import random module + import_random = ast.ImportFrom( + module='random', + names=[ast.alias(name='seed', asname=None)], + level=0 + ) + new_statements.append(import_random) + + # Import numpy.random + import_np = ast.ImportFrom( + module='numpy', + names=[ast.alias(name='random', asname='np_random')], + level=0 + ) + new_statements.append(import_np) + + # Create seed function - Always add this + seed_func_def = ast.FunctionDef( + name='set_random_seed', + args=ast.arguments( + posonlyargs=[], + args=[ast.arg(arg='seed_value', annotation=None)], + kwonlyargs=[], + kw_defaults=[], + defaults=[] + ), + body=[ + ast.Expr( + value=ast.Call( + func=ast.Name(id='seed', ctx=ast.Load()), + args=[ast.Name(id='seed_value', ctx=ast.Load())], + keywords=[] + ) + ), + ast.Expr( + value=ast.Call( + func=ast.Attribute( + value=ast.Name(id='torch', ctx=ast.Load()), + attr='manual_seed' + ), + args=[ast.Name(id='seed_value', ctx=ast.Load())], + keywords=[] + ) + ), + ast.Expr( + value=ast.Call( + func=ast.Attribute( + value=ast.Name(id='np_random', ctx=ast.Load()), + attr='seed' + ), + args=[ast.Name(id='seed_value', ctx=ast.Load())], + keywords=[] + ) + ) + ], + decorator_list=[], + returns=None + ) + new_statements.append(seed_func_def) + + # Process input tensors + for given_stmt in self.cur_inline_test.given_stmts: + input_var = given_stmt.targets[0].id + ref_var = f"{input_var}_ref" + + # Always clone inputs for in-place operations + new_statements.append( + ast.Assign( + targets=[ast.Name(id=ref_var, ctx=ast.Store())], + value=ast.Call( + func=ast.Attribute( + value=given_stmt.value, + attr="clone" + ), + args=[], + keywords=[] + ) + ) + ) + + # Create device-specific versions + for device in self.cur_inline_test.devices: + device_var = f"{input_var}_{device}" + + new_statements.append( + ast.Assign( + targets=[ast.Name(id=device_var, ctx=ast.Store())], + value=ast.Call( + func=ast.Attribute( + value=ast.Name(id=ref_var, ctx=ast.Load()), + attr="to" + ), + args=[ast.Constant(value=device)], + keywords=[] + ) + ) + ) + + # Create device-specific operations + device_input_map = {device: {} for device in self.cur_inline_test.devices} + for device in self.cur_inline_test.devices: + for given_stmt in self.cur_inline_test.given_stmts: + input_var = given_stmt.targets[0].id + device_input_map[device][input_var] = f"{input_var}_{device}" + + # Always set seed before each device operation - no condition check + new_statements.append( + ast.Expr( + value=ast.Call( + func=ast.Name(id='set_random_seed', ctx=ast.Load()), + args=[ast.Constant(value=42)], # Use constant seed 42 + keywords=[] + ) + ) + ) + + device_op = copy.deepcopy(original_op) + + # Replace input references + class ReplaceInputs(ast.NodeTransformer): + def visit_Name(self, node): + if node.id in device_input_map[device]: + return ast.Name(id=device_input_map[device][node.id], ctx=node.ctx) + return node + + device_op = ReplaceInputs().visit(device_op) + device_output = f"output_{device}" + + new_statements.append( + ast.Assign( + targets=[ast.Name(id=device_output, ctx=ast.Store())], + value=device_op + ) + ) + device_outputs.append(device_output) + + # Standard comparison method for all operations - no condition check + comparisons = [] + for i in range(len(device_outputs) - 1): + dev1 = device_outputs[i] + dev2 = device_outputs[i + 1] + + dev1_cpu = f"{dev1}_cpu" + dev2_cpu = f"{dev2}_cpu" + + # Move outputs back to CPU for comparison + new_statements.append( + ast.Assign( + targets=[ast.Name(id=dev1_cpu, ctx=ast.Store())], + value=ast.Call( + func=ast.Attribute( + value=ast.Name(id=dev1, ctx=ast.Load()), + attr="to" + ), + args=[ast.Constant(value="cpu")], + keywords=[] + ) + ) + ) + + new_statements.append( + ast.Assign( + targets=[ast.Name(id=dev2_cpu, ctx=ast.Store())], + value=ast.Call( + func=ast.Attribute( + value=ast.Name(id=dev2, ctx=ast.Load()), + attr="to" + ), + args=[ast.Constant(value="cpu")], + keywords=[] + ) + ) + ) + + # Standard allclose comparison + comparison = self.build_assert_eq( + ast.Call( + func=ast.Attribute( + value=ast.Name(id=dev1_cpu, ctx=ast.Load()), + attr="allclose" + ), + args=[ + ast.Name(id=dev2_cpu, ctx=ast.Load()) + ], + keywords=[ + ast.keyword(arg="rtol", value=ast.Constant(value=1e-4)), + ast.keyword(arg="atol", value=ast.Constant(value=1e-4)), + ast.keyword(arg="equal_nan", value=ast.Constant(value=True)) + ] + ), + ast.Constant(value=True) + ) + comparisons.append(comparison) + + # Replace statements + self.cur_inline_test.previous_stmts = new_statements + self.cur_inline_test.check_stmts = comparisons + + def parse_group(self, node): if isinstance(node, ast.Call) and isinstance(node.func, ast.Name) and node.func.id == self.group_str: # node type is ast.Call, node.func type is ast.Name @@ -988,6 +1231,9 @@ def parse_inline_test(self, node): if isinstance(call.func, ast.Attribute) and call.func.attr == self.assume: self.parse_assume(call) inline_test_call_index += 1 + elif call.func.attr == self.diff_given_str: + self.parse_diff_given(call) + inline_test_call_index += 1 for call in inline_test_calls[inline_test_call_index:]: if isinstance(call.func, ast.Attribute): @@ -1027,9 +1273,15 @@ def parse_inline_test(self, node): self.parse_check_not_same(call) elif call.func.attr == self.fail_str: self.parse_fail(call) + elif call.func.attr == self.diff_test_str: + self.parse_diff_test(call) elif call.func.attr == self.given_str: raise MalformedException( - f"inline test: given() must be called before check_eq()/check_true()/check_false()" + f"inline test: given() must be called before check_eq()/check_true()/check_false()/diff_test()" + ) + elif call.func.attr == self.diff_given_str: + raise MalformedException( + f"inline test: diff_given() must be called before check_eq()/check_true()/check_false()/diff_test()" ) else: raise MalformedException(f"inline test: invalid function call {self.node_to_source_code(call.func)}") From 552a00ed2da9f0fecbce19aca023cf8ca2002ba3 Mon Sep 17 00:00:00 2001 From: hanse141 Date: Sun, 23 Nov 2025 13:50:02 -0500 Subject: [PATCH 6/8] Fixed Diff Given Conditional Placement --- src/inline/plugin.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/inline/plugin.py b/src/inline/plugin.py index 1288bec..a9a3b4e 100644 --- a/src/inline/plugin.py +++ b/src/inline/plugin.py @@ -1231,15 +1231,15 @@ def parse_inline_test(self, node): if isinstance(call.func, ast.Attribute) and call.func.attr == self.assume: self.parse_assume(call) inline_test_call_index += 1 - elif call.func.attr == self.diff_given_str: - self.parse_diff_given(call) - inline_test_call_index += 1 for call in inline_test_calls[inline_test_call_index:]: if isinstance(call.func, ast.Attribute): if call.func.attr == self.given_str: self.parse_given(call) inline_test_call_index += 1 + elif call.func.attr == self.diff_given_str: + self.parse_diff_given(call) + inline_test_call_index += 1 else: break From 7eba8cab0e7246073da426e138d76c95e5528119 Mon Sep 17 00:00:00 2001 From: hanse141 Date: Sun, 23 Nov 2025 15:20:39 -0500 Subject: [PATCH 7/8] Fixed Edge Case for Diff Given < 3.7 --- src/inline/plugin.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/src/inline/plugin.py b/src/inline/plugin.py index a9a3b4e..8de75cc 100644 --- a/src/inline/plugin.py +++ b/src/inline/plugin.py @@ -603,15 +603,22 @@ def parse_diff_given(self, node): PROPERTY = 0 VALUES = 1 + if sys.version_info >= (3, 8, 0): + attr_name = "value" + else: + attr_name = "s" + + if len(node.args) == 2: if self.cur_inline_test.parameterized: raise MalformedException("inline test: Parameterized inline tests currently do not support differential tests.") else: devices = [] for elt in node.args[VALUES].elts: - if elt.value not in {"cpu", "cuda", "mps"}: + value = getattr(elt, attr_name) + if value not in {"cpu", "cuda", "mps"}: raise MalformedException(f"Invalid device: {elt.value}. Must be one of ['cpu', 'cuda', 'mps']") - devices.append(elt.value) + devices.append(value) setattr(self.cur_inline_test, node.args[PROPERTY].id, devices) else: raise MalformedException("inline test: invalid diff_given(), expected 2 args") From 7ab14fc2552199535a6060d8f21ec352aa6e04eb Mon Sep 17 00:00:00 2001 From: hanse141 Date: Fri, 5 Dec 2025 14:06:34 -0500 Subject: [PATCH 8/8] Removed Testing Line Again --- tests/test_plugin.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/test_plugin.py b/tests/test_plugin.py index abdc0af..99f3eba 100644 --- a/tests/test_plugin.py +++ b/tests/test_plugin.py @@ -2,10 +2,6 @@ from _pytest.pytester import Pytester import pytest -# # For testing in Spyder only -# if __name__ == "__main__": -# pytest.main(['-v', '-s']) - # pytest -p pytester class TestInlinetests: