@@ -364,14 +364,56 @@ def _sample(self, X, y):
364364
365365 prev_len = y_ .shape [0 ]
366366 if self .return_indices :
367- X_ , y_ , idx_ = self .enn_ .fit_sample (X_ , y_ )
368- idx_under = idx_under [idx_ ]
367+ X_enn , y_enn , idx_enn = self .enn_ .fit_sample (X_ , y_ )
369368 else :
370- X_ , y_ = self .enn_ .fit_sample (X_ , y_ )
371-
372- if prev_len == y_ .shape [0 ]:
369+ X_enn , y_enn = self .enn_ .fit_sample (X_ , y_ )
370+
371+ # Check the stopping criterion
372+ # 1. If there is no changes for the vector y
373+ # 2. If the number of samples in the other class become inferior to
374+ # the number of samples in the majority class
375+ # 3. If one of the class is disappearing
376+
377+ # Case 1
378+ b_conv = (prev_len == y_enn .shape [0 ])
379+
380+ # Case 2
381+ stats_enn = Counter (y_enn )
382+ self .logger .debug ('Current ENN stats: %s' , stats_enn )
383+ # Get the number of samples in the non-minority classes
384+ count_non_min = np .array ([val for val , key
385+ in zip (stats_enn .itervalues (),
386+ stats_enn .iterkeys ())
387+ if key != self .min_c_ ])
388+ self .logger .debug ('Number of samples in the non-majority'
389+ ' classes: %s' , count_non_min )
390+ # Check the minority stop to be the minority
391+ b_min_bec_maj = np .any (count_non_min < self .stats_c_ [self .min_c_ ])
392+
393+ # Case 3
394+ b_remove_maj_class = (len (stats_enn ) < len (self .stats_c_ ))
395+
396+ if b_conv or b_min_bec_maj or b_remove_maj_class :
397+ # If this is a normal convergence, get the last data
398+ if b_conv :
399+ if self .return_indices :
400+ X_ , y_ , = X_enn , y_enn
401+ idx_under = idx_under [idx_enn ]
402+ else :
403+ X_ , y_ , = X_enn , y_enn
404+ # Log the variables to explain the stop of the algorithm
405+ self .logger .debug ('RENN converged: %s' , b_conv )
406+ self .logger .debug ('RENN minority become majority: %s' ,
407+ b_min_bec_maj )
408+ self .logger .debug ('RENN remove one class: %s' ,
409+ b_remove_maj_class )
373410 break
374411
412+ # Update the data for the next iteration
413+ X_ , y_ , = X_enn , y_enn
414+ if self .return_indices :
415+ idx_under = idx_under [idx_enn ]
416+
375417 self .logger .info ('Under-sampling performed: %s' , Counter (y_ ))
376418
377419 X_resampled , y_resampled = X_ , y_
0 commit comments