🚀 The feature
Feature request: expose a `pin_memory_map` parameter on the `PinMemory` node, defaulting to the function used today (`pin_memory` from torch):
```python
from typing import Callable, TypeVar

from torch.utils.data._utils.pin_memory import pin_memory
from torchdata.nodes import BaseNode

T = TypeVar("T")

class PinMemory(BaseNode[T]):
    def __init__(
        self,
        source: BaseNode[T],
        pin_memory_device: str = "",
        snapshot_frequency: int = 1,
        # "DeviceType" stands for whatever device type the loop passes through
        pin_memory_map: Callable[[T, "DeviceType | None"], T] = pin_memory,
    ): ...
```
The same parameter and default need to appear in `_pin_memory_loop`, and would override the function called here:
https://github.com/pytorch/data/blob/dbf04a9108d38066efa60ce24bdcb8190a51c0bd/torchdata/nodes/pin_memory.py#L81
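A minimal sketch of the change, purely illustrative (the real `_pin_memory_loop` has a different signature; the queue names and sentinel handling below are invented for this example):
```python
import queue

import torch
from torch.utils.data._utils.pin_memory import pin_memory

# Illustrative sketch only, not the actual torchdata implementation:
# the point is that the hard-coded pin_memory call at the linked line
# becomes the injected pin_memory_map.
def _pin_memory_loop(in_q: queue.Queue, out_q: queue.Queue,
                     device: torch.device | None,
                     pin_memory_map=pin_memory):
    while True:
        item = in_q.get()
        if item is None:  # sentinel: upstream exhausted
            out_q.put(None)
            return
        out_q.put(pin_memory_map(item, device))
```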
Motivation, pitch
The current PyTorch implementation of `pin_memory` only partially allows custom objects to implement a "pin memory" interface:
https://github.com/pytorch/pytorch/blob/50d4698ac8c12ad8399773aa157d25316c7c345e/torch/utils/data/_utils/pin_memory.py#L108
Note that the device is not passed when `pin_memory` is called on the object. Passing it through would let objects implement their own `def pin_memory(self, device: torch.device | None = None) -> Self`, which the `PinMemory` node would then use. One could then pass a map such as:
```python
from typing import Any, Protocol, Self, runtime_checkable

import torch
from torch.utils.data._utils.pin_memory import pin_memory


@runtime_checkable
class SupportsPinMemory(Protocol):
    def pin_memory(self, device: torch.device | None = None) -> Self: ...


def pin_memory_custom(data: Any, device: torch.device | None) -> Any:
    if isinstance(data, SupportsPinMemory):
        return data.pin_memory(device=device)
    # Otherwise fall back to the default PyTorch pin_memory
    return pin_memory(data, device)


node = PinMemory(source=other_node, pin_memory_map=pin_memory_custom)
```
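For illustration, a hypothetical user-defined batch type that would satisfy the protocol (the `Batch` class and its fields are invented for this sketch):
```python
from dataclasses import dataclass

import torch


@dataclass
class Batch:
    images: torch.Tensor
    labels: torch.Tensor

    def pin_memory(self, device: torch.device | None = None) -> "Batch":
        # device is accepted to satisfy SupportsPinMemory; the tensors
        # are pinned into page-locked host memory for fast H2D copies.
        return Batch(
            images=self.images.pin_memory(),
            labels=self.labels.pin_memory(),
        )
```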
Alternatives
Of course this can be done today with a custom Mapper, but my understanding is that the reimplemented `_pin_memory_loop` the `PinMemory` node uses plays nicely with the rest of the nodes in a pipeline without consuming all CPU cores.
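For reference, a sketch of that workaround, assuming `torchdata.nodes.Mapper` takes `source` and `map_fn`, and reusing `pin_memory_custom` from above:
```python
from functools import partial

import torch
from torchdata.nodes import Mapper

# Workaround today: pin via a plain Mapper node. Unlike the PinMemory
# node, this pins inline rather than in a dedicated background thread.
pin_node = Mapper(
    source=other_node,
    map_fn=partial(pin_memory_custom, device=torch.device("cuda")),
)
```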
Additional context
No response