Skip to content

Commit d48d457

Browse files
tcauthfacebook-github-bot
authored andcommitted
Add MultiProcessMock for cross-process testing
Summary: **WHY:** Distributed testing in TorchRec requires the ability to mock functions and methods across multiple processes. Without this capability, tests that spawn multiple processes cannot easily mock expensive operations, external dependencies, or specific behaviors, making it difficult to write comprehensive unit tests for distributed PyTorch code. This leads to either incomplete test coverage or tests that require actual expensive resources. This change is also for T245908194 **WHAT:** 1. Created `MultiProcessMock` class in `/data/sandcastle/boxes/fbsource/fbcode/torchrec/distributed/test_utils/multi_process.py`: - Added `mocks` list to maintain mock configurations - Added `add_mock(target, return_value, side_effect, **kwargs)` to register new mocks - Added `apply_mocks()` to apply all registered mocks in child processes - Added `clear_mocks()` to clear all registered mocks 2. Integrated `MultiProcessMock` with `MultiProcessTestBase`: - Added `_mock_manager` instance variable initialized in `__init__` - Added public `add_mock(...)` method for users to register mocks - Created `_callable_wrapper_with_mocks()` static method to apply mocks before calling test functions - Modified `_run_multi_process_test()` to pass mock manager to child processes and use the wrapper - Modified `_run_multi_process_test_per_rank()` to pass mock manager to child processes and use the wrapper 3. Added comprehensive test suite in `/data/sandcastle/boxes/fbsource/fbcode/torchrec/distributed/tests/test_multi_process_mock.py`: - 6 unit tests for `MultiProcessMock` class functionality - 6 integration tests including baseline tests to ensure backward compatibility - Added BUCK target for the new test file **TEST:** All 12 tests pass successfully: - Unit tests verify mock storage, multiple mocks, side effects, clearing, and application - Integration tests verify cross-process mocking with return values, multiple mocks, and side effects - Baseline tests confirm backward compatibility - existing tests without mocks continue to work exactly as before - Fixed pickling issue by using module-level functions for side effects (required for multiprocessing) Ran: `buck2 test torchrec/distributed/tests:test_multi_process_mock` Result: Tests finished: Pass 12. Fail 0. Fatal 0. Skip 0. Differential Revision: D88696058
1 parent 7b3effd commit d48d457

File tree

2 files changed

+443
-2
lines changed

2 files changed

+443
-2
lines changed

torchrec/distributed/test_utils/multi_process.py

Lines changed: 122 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import os
1515
import unittest
1616
from typing import Any, Callable, Dict, List, Optional
17+
from unittest.mock import patch
1718

1819
import torch
1920
import torch.distributed as dist
@@ -25,6 +26,65 @@
2526
)
2627

2728

29+
class MultiProcessMock:
30+
"""
31+
Manages cross-process mocks for multi-process testing.
32+
33+
This class maintains a collection of mocks that can be applied across
34+
different processes in distributed testing scenarios.
35+
"""
36+
37+
def __init__(self) -> None:
38+
self.mocks: List[Dict[str, Any]] = []
39+
40+
def add_mock(
41+
self,
42+
target: str,
43+
return_value: Any = None,
44+
side_effect: Any = None,
45+
**kwargs: Any,
46+
) -> None:
47+
"""
48+
Add a new cross-process mock.
49+
50+
Args:
51+
target: The target to mock (e.g., 'module.function')
52+
return_value: The return value for the mock
53+
side_effect: The side effect for the mock
54+
**kwargs: Additional arguments to pass to the mock
55+
"""
56+
mock_config = {
57+
"target": target,
58+
"return_value": return_value,
59+
"side_effect": side_effect,
60+
**kwargs,
61+
}
62+
self.mocks.append(mock_config)
63+
64+
def apply_mocks(self) -> List[Any]:
65+
"""
66+
Apply all registered mocks and return context managers.
67+
68+
Returns:
69+
List of active mock context managers
70+
"""
71+
active_patches = []
72+
for mock_config in self.mocks:
73+
target = mock_config["target"]
74+
return_value = mock_config.get("return_value")
75+
side_effect = mock_config.get("side_effect")
76+
77+
patcher = patch(target, return_value=return_value, side_effect=side_effect)
78+
active_patch = patcher.__enter__()
79+
active_patches.append((patcher, active_patch))
80+
81+
return active_patches
82+
83+
def clear_mocks(self) -> None:
84+
"""Clear all registered mocks."""
85+
self.mocks.clear()
86+
87+
2888
class MultiProcessContext:
2989
def __init__(
3090
self,
@@ -111,6 +171,32 @@ def __init__(
111171
self._mp_init_mode: str = mp_init_mode
112172
logging.info(f"Using {self._mp_init_mode} for multiprocessing")
113173

174+
# Initialize MultiProcessMock
175+
self._mock_manager = MultiProcessMock()
176+
177+
def add_mock(
178+
self,
179+
target: str,
180+
return_value: Any = None,
181+
side_effect: Any = None,
182+
**kwargs: Any,
183+
) -> None:
184+
"""
185+
Add a new cross-process mock that will be applied during test execution.
186+
187+
Args:
188+
target: The target to mock (e.g., 'module.function')
189+
return_value: The return value for the mock
190+
side_effect: The side effect for the mock
191+
**kwargs: Additional arguments to pass to the mock
192+
"""
193+
self._mock_manager.add_mock(
194+
target=target,
195+
return_value=return_value,
196+
side_effect=side_effect,
197+
**kwargs,
198+
)
199+
114200
@seed_and_log
115201
def setUp(self) -> None:
116202
os.environ["MASTER_ADDR"] = str("localhost")
@@ -149,8 +235,10 @@ def _run_multi_process_test(
149235
for rank in range(world_size):
150236
kwargs["rank"] = rank
151237
kwargs["world_size"] = world_size
238+
kwargs["_mock_manager"] = self._mock_manager
152239
p = ctx.Process(
153-
target=callable,
240+
target=self._callable_wrapper_with_mocks,
241+
args=(callable,),
154242
kwargs=kwargs,
155243
)
156244
p.start()
@@ -176,9 +264,11 @@ def _run_multi_process_test_per_rank(
176264
kwargs = {}
177265
kwargs["rank"] = rank
178266
kwargs["world_size"] = world_size
267+
kwargs["_mock_manager"] = self._mock_manager
179268
kwargs.update(kwargs_per_rank[rank])
180269
p = ctx.Process(
181-
target=callable,
270+
target=self._callable_wrapper_with_mocks,
271+
args=(callable,),
182272
kwargs=kwargs,
183273
)
184274
p.start()
@@ -188,6 +278,36 @@ def _run_multi_process_test_per_rank(
188278
p.join()
189279
self.assertEqual(0, p.exitcode)
190280

281+
@staticmethod
282+
def _callable_wrapper_with_mocks(
283+
callable: Callable[..., None],
284+
_mock_manager: Optional[MultiProcessMock] = None,
285+
**kwargs: Any,
286+
) -> None:
287+
"""
288+
Wrapper that applies mocks before calling the target callable.
289+
290+
Args:
291+
callable: The function to call
292+
_mock_manager: Optional mock manager containing mocks to apply
293+
**kwargs: Additional keyword arguments to pass to the callable
294+
"""
295+
active_patches = []
296+
try:
297+
# Apply mocks if a mock manager is provided
298+
if _mock_manager is not None:
299+
active_patches = _mock_manager.apply_mocks()
300+
301+
# Remove _mock_manager from kwargs before calling the target
302+
kwargs.pop("_mock_manager", None)
303+
304+
# Call the actual test callable
305+
callable(**kwargs)
306+
finally:
307+
# Clean up all patches
308+
for patcher, _ in active_patches:
309+
patcher.__exit__(None, None, None)
310+
191311

192312
def _wrapper_func_for_multiprocessing(args): # pyre-ignore[2, 3]
193313
"""Wrapper function that unpacks arguments and calls the original func"""

0 commit comments

Comments
 (0)