1+ import random
12import functools
2- from py_paddle import swig_paddle as api
3- from py_paddle import DataProviderConverter
3+
44from paddle .trainer_config_helpers import *
55from paddle .trainer_config_helpers import inputs as ipts
6- import random
6+
7+ from .base import *
8+ from .. import DataProviderConverter
9+ from .. import swig_paddle as api
710
811__all__ = [
9- 'RunnerChainItem ' , 'Runner' , 'DeviceChainItem ' , 'CreateGradientMachine' ,
12+ 'RunnerItem ' , 'Runner' , 'DeviceItem ' , 'CreateGradientMachine' ,
1013 'RandomInitializeParams' , 'BasicLocalParameterUpdater' , 'network' ,
1114 'BasicTrainerDataProvider' , 'BasicDataProviderOps' ,
1215 'BasicGradientMachineTrainOps' , 'Counter' , 'BatchEvaluate' ,
@@ -89,113 +92,9 @@ def __optimize_graph_func__():
8992 return __impl__
9093
9194
92- class RunnerChainItem (object ):
93- def __init__ (self ):
94- pass
95-
96- def initialize (self , context , next_callback ):
97- next_callback (context )
98-
99- def finalize (self , context , next_callback ):
100- next_callback (context )
101-
102- def on_pass_begin (self , context , next_callback ):
103- next_callback (context )
104-
105- def on_pass_end (self , context , next_callback ):
106- next_callback (context )
107-
108- def on_batch_begin (self , context , next_callback ):
109- return next_callback (context )
110-
111- def on_batch_end (self , context , next_callback ):
112- return next_callback (context )
113-
114-
115- def default_next_callback (* args , ** kwargs ):
116- return False
117-
118-
119- class RunnerContext (object ):
120- pass
121-
122-
123- class Runner (object ):
124- def __init__ (self ):
125- self .chains = []
126-
127- self .begin_pass = None
128- self .end_pass = None
129- self .begin_batch = None
130- self .end_batch = None
131- self .finalize = None
132-
133- self .context = RunnerContext ()
134- self .context .runner = self
135-
136- def add_chain_item (self , item ):
137- assert isinstance (item , RunnerChainItem )
138- self .chains .append (item )
139-
140- def initialize (self , parent = None ):
141- if None not in [
142- self .begin_pass , self .end_pass , self .begin_batch ,
143- self .end_batch , self .finalize
144- ]:
145- return False
146- else :
147- assert len (self .chains ) != 0
148- actual_init = default_next_callback
149- self .begin_pass = default_next_callback
150- self .end_pass = default_next_callback
151- self .begin_batch = default_next_callback
152- self .end_batch = default_next_callback
153- self .finalize = default_next_callback
154-
155- for chain in reversed (self .chains ):
156- assert isinstance (chain , RunnerChainItem )
157- actual_init = functools .partial (
158- chain .initialize , next_callback = actual_init )
159- self .begin_pass = functools .partial (
160- chain .on_pass_begin , next_callback = self .begin_pass )
161- self .end_pass = functools .partial (
162- chain .on_pass_end , next_callback = self .end_pass )
163- self .begin_batch = functools .partial (
164- chain .on_batch_begin , next_callback = self .begin_batch )
165- self .end_batch = functools .partial (
166- chain .on_batch_end , next_callback = self .end_batch )
167- self .finalize = functools .partial (
168- chain .finalize , next_callback = self .finalize )
169-
170- if parent is not None :
171- self .context .parent = parent
172-
173- actual_init (self .context )
174- return True
175-
176- def run_one_pass (self , parent = None ):
177- if parent is not None :
178- self .context .parent = parent
179-
180- self .begin_pass (self .context )
181- exit_flag = False
182- while not exit_flag :
183- exit_flag = self .begin_batch (self .context )
184- if exit_flag :
185- break
186- exit_flag = self .end_batch (self .context )
187- self .end_pass (self .context )
188-
189- def __enter__ (self ):
190- self .initialize ()
191-
192- def __exit__ (self , exc_type , exc_val , exc_tb ):
193- self .finalize (self .context )
194-
195-
196- class DeviceChainItem (RunnerChainItem ):
95+ class DeviceItem (RunnerItem ):
19796 def __init__ (self , use_gpu = False , device_count = 4 ):
198- RunnerChainItem .__init__ (self )
97+ RunnerItem .__init__ (self )
19998 self .use_gpu = use_gpu
20099 self .device_count = device_count
201100
@@ -205,9 +104,9 @@ def initialize(self, context, next_callback):
205104 next_callback (context )
206105
207106
208- class CreateGradientMachine (RunnerChainItem ):
107+ class CreateGradientMachine (RunnerItem ):
209108 def __init__ (self , network ):
210- RunnerChainItem .__init__ (self )
109+ RunnerItem .__init__ (self )
211110 assert isinstance (network , NetworkConfig )
212111 self .__network__ = network
213112
@@ -237,9 +136,9 @@ def finalize(self, context, next_callback):
237136 next_callback (context )
238137
239138
240- class RandomInitializeParams (RunnerChainItem ):
139+ class RandomInitializeParams (RunnerItem ):
241140 def __init__ (self ):
242- RunnerChainItem .__init__ (self )
141+ RunnerItem .__init__ (self )
243142
244143 def initialize (self , context , next_callback ):
245144 assert hasattr (context , 'gradient_machine' ) and isinstance (
@@ -248,12 +147,12 @@ def initialize(self, context, next_callback):
248147 next_callback (context )
249148
250149
251- class BasicLocalParameterUpdaterOps (RunnerChainItem ):
150+ class BasicLocalParameterUpdaterOps (RunnerItem ):
252151 def __init__ (self ,
253152 updater_name = 'updater' ,
254153 batch_size_name = 'current_batch_size' ,
255154 cost_name = 'current_cost' ):
256- RunnerChainItem .__init__ (self )
155+ RunnerItem .__init__ (self )
257156 self .__updater_name__ = updater_name
258157 self .__batch_size_name__ = batch_size_name
259158 self .__cost_name__ = cost_name
@@ -311,9 +210,9 @@ def initialize(self, context, next_callback):
311210 next_callback (context )
312211
313212
314- class BasicGradientMachineTrainOps (RunnerChainItem ):
213+ class BasicGradientMachineTrainOps (RunnerItem ):
315214 def __init__ (self ):
316- RunnerChainItem .__init__ (self )
215+ RunnerItem .__init__ (self )
317216 self .__out_args__ = api .Arguments .createArguments (0 )
318217
319218 def on_batch_begin (self , context , next_callback ):
@@ -334,9 +233,9 @@ def on_batch_begin(self, context, next_callback):
334233 return next_callback (context )
335234
336235
337- class Counter (RunnerChainItem ):
236+ class Counter (RunnerItem ):
338237 def __init__ (self ):
339- RunnerChainItem .__init__ (self )
238+ RunnerItem .__init__ (self )
340239
341240 def initialize (self , context , next_callback ):
342241 context .current_pass_id = 0
@@ -353,9 +252,9 @@ def on_pass_end(self, context, next_callback):
353252 context .current_pass_id += 1
354253
355254
356- class BaseEvaluate (RunnerChainItem ):
255+ class BaseEvaluate (RunnerItem ):
357256 def __init__ (self , prefix = None ):
358- RunnerChainItem .__init__ (self )
257+ RunnerItem .__init__ (self )
359258 self .__evaluator__ = None
360259 if prefix is None :
361260 prefix = ''
@@ -409,9 +308,9 @@ def on_pass_end(self, context, next_callback):
409308 self .__evaluator__ .finish ()
410309
411310
412- class BasicGradientMachineTestOps (RunnerChainItem ):
311+ class BasicGradientMachineTestOps (RunnerItem ):
413312 def __init__ (self ):
414- RunnerChainItem .__init__ (self )
313+ RunnerItem .__init__ (self )
415314 self .__out_args__ = api .Arguments .createArguments (0 )
416315
417316 def on_pass_begin (self , context , next_callback ):
@@ -428,9 +327,9 @@ def on_pass_end(self, context, next_callback):
428327 next_callback (context )
429328
430329
431- class InheritGradientMachineUpdater (RunnerChainItem ):
330+ class InheritGradientMachineUpdater (RunnerItem ):
432331 def __init__ (self ):
433- RunnerChainItem .__init__ (self )
332+ RunnerItem .__init__ (self )
434333
435334 def initialize (self , context , next_callback ):
436335 if context .parent is not None :
@@ -449,18 +348,18 @@ def on_batch_begin(self, context, next_callback):
449348 return next_callback (context )
450349
451350
452- class TestOnPassEnd (RunnerChainItem ):
351+ class TestOnPassEnd (RunnerItem ):
453352 def __init__ (self , ** kwargs ):
454- RunnerChainItem .__init__ (self )
353+ RunnerItem .__init__ (self )
455354 self .__test_runner__ = Runner ()
456- self .__test_runner__ .add_chain_item (InheritGradientMachineUpdater ())
457- self .__test_runner__ .add_chain_item (BasicTestDataProvider (** kwargs ))
458- self .__test_runner__ .add_chain_item (BasicGradientMachineTestOps ())
459- self .__test_runner__ .add_chain_item (PassEvaluate (prefix = 'Test: ' ))
355+ self .__test_runner__ .add_item (InheritGradientMachineUpdater ())
356+ self .__test_runner__ .add_item (BasicTestDataProvider (** kwargs ))
357+ self .__test_runner__ .add_item (BasicGradientMachineTestOps ())
358+ self .__test_runner__ .add_item (PassEvaluate (prefix = 'Test: ' ))
460359
461360 def initialize (self , context , next_callback ):
462361 next_callback (context )
463- self .__test_runner__ .initialize (context )
362+ self .__test_runner__ .__initialize__ (context )
464363
465364 def on_pass_end (self , context , next_callback ):
466365 self .__test_runner__ .run_one_pass (parent = context )
@@ -515,9 +414,9 @@ def next(self):
515414 raise StopIteration
516415
517416
518- class BasicDataProviderOps (RunnerChainItem ):
417+ class BasicDataProviderOps (RunnerItem ):
519418 def __init__ (self , provider_name = 'data_provider' ):
520- RunnerChainItem .__init__ (self )
419+ RunnerItem .__init__ (self )
521420 self .__provider_name__ = provider_name
522421
523422 def __get_provider__ (self , context ):
@@ -575,9 +474,9 @@ def initialize(self, context, next_callback):
575474BasicTestDataProvider = data_provider_creator (False )
576475
577476
578- class SaveParamsOnPassEnd (RunnerChainItem ):
477+ class SaveParamsOnPassEnd (RunnerItem ):
579478 def __init__ (self ):
580- RunnerChainItem .__init__ (self )
479+ RunnerItem .__init__ (self )
581480
582481 def on_pass_end (self , context , next_callback ):
583482 context .updater .catchUpWith ()
@@ -591,12 +490,12 @@ def on_pass_end(self, context, next_callback):
591490class RunnerBuilder (object ):
592491 def __init__ (self , network , use_gpu = False , device_count = 1 ):
593492 self .__runner__ = Runner ()
594- self .__runner__ .add_chain_item (Counter ())
493+ self .__runner__ .add_item (Counter ())
595494 self .__network__ = network
596- self .__runner__ .add_chain_item (
597- DeviceChainItem (
495+ self .__runner__ .add_item (
496+ DeviceItem (
598497 use_gpu = use_gpu , device_count = device_count ))
599- self .__runner__ .add_chain_item (
498+ self .__runner__ .add_item (
600499 CreateGradientMachine (network = self .__network__ ))
601500
602501 self .__train_data__ = None
@@ -605,7 +504,7 @@ def __init__(self, network, use_gpu=False, device_count=1):
605504 self .__evaluate__ = []
606505
607506 def with_std_random_init_params (self ):
608- self .__runner__ .add_chain_item (RandomInitializeParams ())
507+ self .__runner__ .add_item (RandomInitializeParams ())
609508 return self
610509
611510 def with_train_data (self , method , file_list , batch_size = None , ** kwargs ):
@@ -659,9 +558,9 @@ def with_std_local_trainer(self, **kwargs):
659558 ).with_batch_evaluator ().with_std_param_saver ()
660559
661560 def build (self ):
662- self .__runner__ .add_chain_item (self .__train_data__ )
663- self .__runner__ .add_chain_item (self .__updater__ )
664- self .__runner__ .add_chain_item (self .__gradient_machine__ )
561+ self .__runner__ .add_item (self .__train_data__ )
562+ self .__runner__ .add_item (self .__updater__ )
563+ self .__runner__ .add_item (self .__gradient_machine__ )
665564 for each in self .__evaluate__ :
666- self .__runner__ .add_chain_item (each )
565+ self .__runner__ .add_item (each )
667566 return self .__runner__
0 commit comments