@@ -68,32 +68,53 @@ def __init__(self, datapipe: IterDataPipe, key_value_fn: Optional[Callable] = No
6868 _check_unpickable_fn (key_value_fn )
6969 self .key_value_fn = key_value_fn # type: ignore[assignment]
7070 self ._map = None
71+ self ._itr = None
72+ self ._depleted = False
7173
def _load_map(self):
    """Eagerly consume whatever is left of the source into ``self._map``.

    On first use the cache dict and the iterator over ``self.datapipe`` are
    created. Items are then pulled through ``self._load_next_item()`` until
    it raises ``StopIteration``, at which point ``self._depleted`` is set so
    later calls return immediately.

    Any other exception raised by ``self._load_next_item()`` propagates to
    the caller; items loaded before the failure stay in ``self._map``.
    """
    if self._map is None:
        self._map = {}
        self._itr = iter(self.datapipe)
    if self._depleted:
        return
    # One try around the whole loop: StopIteration is the only exit, so
    # there is no need to re-enter a try block or re-test the flag per item.
    try:
        while True:
            self._load_next_item()
    except StopIteration:
        self._depleted = True
8683
def __getitem__(self, index):
    """Return the value stored under ``index``, loading lazily if needed.

    Lookup order:

    1. Items already cached in ``self._map``.
    2. Otherwise, keep pulling items via ``self._load_next_item()`` (which
       also caches them) until one whose key equals ``index`` appears or
       the source is exhausted.

    Raises:
        IndexError: if the source is exhausted without producing ``index``.
    """
    if self._map is None:
        # First access: create the cache and start iterating the source.
        self._map = {}
        self._itr = iter(self.datapipe)

    try:
        return self._map[index]
    except KeyError:
        pass  # Not cached yet; fall through to the lazy scan below.

    try:
        while not self._depleted:
            key, value = self._load_next_item()
            if key == index:
                return value
    except StopIteration:
        self._depleted = True

    # Raised outside any ``except`` block so the traceback is not cluttered
    # with the internal KeyError/StopIteration used for control flow.
    raise IndexError(f"Index {index} is invalid for IterToMapConverter.")
94100
def _load_next_item(self):
    """Pull one element from ``self._itr`` and cache it in ``self._map``.

    The element is passed through ``self.key_value_fn`` when one is set and
    must then be a length-2 sequence that unpacks to ``(key, value)``. A key
    that is already cached triggers a warning and is overwritten.

    Returns:
        The ``(key, value)`` pair that was just stored.

    Raises:
        StopIteration: only when ``self._itr`` itself is exhausted.
        RuntimeError: if ``key_value_fn`` raises ``StopIteration``.
        TypeError: if the converted element has no ``len()``.
        ValueError: if the converted element does not have length 2.
    """
    elem = next(self._itr)  # StopIteration here means the source is exhausted.
    if self.key_value_fn is None:
        inp = elem
    else:
        try:
            inp = self.key_value_fn(elem)
        except StopIteration as err:
            # Callers treat StopIteration as "source exhausted"; letting one
            # leak out of the user callback would silently truncate the map.
            raise RuntimeError("`key_value_fn` raised StopIteration") from err
    try:
        length = len(inp)
    except TypeError as err:
        raise TypeError(
            f"Cannot convert dictionary update element {type(inp)} ({inp}) to a sequence"
        ) from err
    if length != 2:
        raise ValueError(f"dictionary update sequence element has length {length}, 2 is required")
    key, value = inp
    if key in self._map:
        warnings.warn(f"Found duplicate key {key}. Please check your `key_value_fn`")
    self._map[key] = value
    return key, value
115+
95116 def __len__ (self ):
96- if self ._map is not None :
117+ if self ._depleted :
97118 return len (self ._map ) # type: ignore[arg-type]
98119 try :
99120 return len (self .datapipe )
0 commit comments