@@ -145,9 +145,10 @@ def emit_unpack_instruction(self, *, loop_indices=None):
145145
146146class GlobalPack (Pack ):
147147
148- def __init__ (self , outer , access ):
148+ def __init__ (self , outer , access , init_with_zero = False ):
149149 self .outer = outer
150150 self .access = access
151+ self .init_with_zero = init_with_zero
151152
152153 def kernel_arg (self , loop_indices = None ):
153154 pack = self .pack (loop_indices )
@@ -169,11 +170,15 @@ def pack(self, loop_indices=None):
169170 # vectorisation loop transformations privatise these reduction
170171 # variables. The extra memory movement cost is minimal.
171172 loop_indices = self .pick_loop_indices (* loop_indices )
172- if self .access in {INC , WRITE }:
173+ if self .init_with_zero :
174+ also_zero = {MIN , MAX }
175+ else :
176+ also_zero = set ()
177+ if self .access in {INC , WRITE } | also_zero :
173178 val = Zero ((), self .outer .dtype )
174179 multiindex = MultiIndex (* (Index (e ) for e in shape ))
175180 self ._pack = Materialise (PackInst (loop_indices ), val , multiindex )
176- elif self .access in {READ , RW , MIN , MAX }:
181+ elif self .access in {READ , RW , MIN , MAX } - also_zero :
177182 multiindex = MultiIndex (* (Index (e ) for e in shape ))
178183 expr = Indexed (self .outer , multiindex )
179184 self ._pack = Materialise (PackInst (loop_indices ), expr , multiindex )
@@ -203,13 +208,15 @@ def emit_unpack_instruction(self, *, loop_indices=None):
203208
204209class DatPack (Pack ):
205210 def __init__ (self , outer , access , map_ = None , interior_horizontal = False ,
206- view_index = None , layer_bounds = None ):
211+ view_index = None , layer_bounds = None ,
212+ init_with_zero = False ):
207213 self .outer = outer
208214 self .map_ = map_
209215 self .access = access
210216 self .interior_horizontal = interior_horizontal
211217 self .view_index = view_index
212218 self .layer_bounds = layer_bounds
219+ self .init_with_zero = init_with_zero
213220
214221 def _mask (self , map_ ):
215222 """Override this if the map_ needs a masking condition."""
@@ -245,11 +252,15 @@ def pack(self, loop_indices=None):
245252 if self .view_index is None :
246253 shape = shape + self .outer .shape [1 :]
247254
248- if self .access in {INC , WRITE }:
255+ if self .init_with_zero :
256+ also_zero = {MIN , MAX }
257+ else :
258+ also_zero = set ()
259+ if self .access in {INC , WRITE } | also_zero :
249260 val = Zero ((), self .outer .dtype )
250261 multiindex = MultiIndex (* (Index (e ) for e in shape ))
251262 self ._pack = Materialise (PackInst (), val , multiindex )
252- elif self .access in {READ , RW , MIN , MAX }:
263+ elif self .access in {READ , RW , MIN , MAX } - also_zero :
253264 multiindex = MultiIndex (* (Index (e ) for e in shape ))
254265 expr , mask = self ._rvalue (multiindex , loop_indices = loop_indices )
255266 if mask is not None :
@@ -577,8 +588,9 @@ def emit_unpack_instruction(self, *,
577588
578589class WrapperBuilder (object ):
579590
580- def __init__ (self , * , iterset , iteration_region = None , single_cell = False ,
591+ def __init__ (self , * , kernel , iterset , iteration_region = None , single_cell = False ,
581592 pass_layer_to_kernel = False , forward_arg_types = ()):
593+ self .kernel = kernel
582594 self .arguments = []
583595 self .argument_accesses = []
584596 self .packed_args = []
@@ -593,6 +605,10 @@ def __init__(self, *, iterset, iteration_region=None, single_cell=False,
593605 self .single_cell = single_cell
594606 self .forward_arguments = tuple (Argument ((), fa , pfx = "farg" ) for fa in forward_arg_types )
595607
608+ @property
609+ def requires_zeroed_output_arguments (self ):
610+ return self .kernel .requires_zeroed_output_arguments
611+
596612 @property
597613 def subset (self ):
598614 return isinstance (self .iterset , Subset )
@@ -605,9 +621,6 @@ def extruded(self):
605621 def constant_layers (self ):
606622 return self .extruded and self .iterset .constant_layers
607623
608- def set_kernel (self , kernel ):
609- self .kernel = kernel
610-
611624 @cached_property
612625 def loop_extents (self ):
613626 return (Argument ((), IntType , name = "start" ),
@@ -722,7 +735,8 @@ def add_argument(self, arg):
722735 shape = (None , * a .data .shape [1 :])
723736 argument = Argument (shape , a .data .dtype , pfx = "mdat" )
724737 packs .append (a .data .pack (argument , arg .access , self .map_ (a .map , unroll = a .unroll_map ),
725- interior_horizontal = interior_horizontal ))
738+ interior_horizontal = interior_horizontal ,
739+ init_with_zero = self .requires_zeroed_output_arguments ))
726740 self .arguments .append (argument )
727741 pack = MixedDatPack (packs , arg .access , arg .dtype , interior_horizontal = interior_horizontal )
728742 self .packed_args .append (pack )
@@ -740,15 +754,17 @@ def add_argument(self, arg):
740754 pfx = "dat" )
741755 pack = arg .data .pack (argument , arg .access , self .map_ (arg .map , unroll = arg .unroll_map ),
742756 interior_horizontal = interior_horizontal ,
743- view_index = view_index )
757+ view_index = view_index ,
758+ init_with_zero = self .requires_zeroed_output_arguments )
744759 self .arguments .append (argument )
745760 self .packed_args .append (pack )
746761 self .argument_accesses .append (arg .access )
747762 elif arg ._is_global :
748763 argument = Argument (arg .data .dim ,
749764 arg .data .dtype ,
750765 pfx = "glob" )
751- pack = GlobalPack (argument , arg .access )
766+ pack = GlobalPack (argument , arg .access ,
767+ init_with_zero = self .requires_zeroed_output_arguments )
752768 self .arguments .append (argument )
753769 self .packed_args .append (pack )
754770 self .argument_accesses .append (arg .access )
0 commit comments