Skip to content

Commit fedbccc

Browse files
authored
fix(BatchPublishing): Make topic.AddToBatch threadsafe (#622)
topic.Publish is already thread-safe, and topic.AddToBatch should follow the same semantics. This matters for integration with Prysm, which uses a separate goroutine for each message it would like to batch.
1 parent 3f89e43 commit fedbccc

File tree

4 files changed

+88
-55
lines changed

4 files changed

+88
-55
lines changed

gossipsub_test.go

Lines changed: 70 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -3682,66 +3682,85 @@ func BenchmarkRoundRobinMessageIDScheduler(b *testing.B) {
36823682
}
36833683

36843684
func TestMessageBatchPublish(t *testing.T) {
3685-
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
3686-
defer cancel()
3687-
hosts := getDefaultHosts(t, 20)
3688-
3689-
msgIDFn := func(msg *pb.Message) string {
3690-
hdr := string(msg.Data[0:16])
3691-
msgID := strings.SplitN(hdr, " ", 2)
3692-
return msgID[0]
3693-
}
3694-
const numMessages = 100
3695-
// +8 to account for the gossiping overhead
3696-
psubs := getGossipsubs(ctx, hosts, WithMessageIdFn(msgIDFn), WithPeerOutboundQueueSize(numMessages+8), WithValidateQueueSize(numMessages+8))
3685+
concurrentAdds := []bool{false, true}
3686+
for _, concurrentAdd := range concurrentAdds {
3687+
t.Run(fmt.Sprintf("WithConcurrentAdd=%v", concurrentAdd), func(t *testing.T) {
3688+
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
3689+
defer cancel()
3690+
hosts := getDefaultHosts(t, 20)
3691+
3692+
msgIDFn := func(msg *pb.Message) string {
3693+
hdr := string(msg.Data[0:16])
3694+
msgID := strings.SplitN(hdr, " ", 2)
3695+
return msgID[0]
3696+
}
3697+
const numMessages = 100
3698+
// +8 to account for the gossiping overhead
3699+
psubs := getGossipsubs(ctx, hosts, WithMessageIdFn(msgIDFn), WithPeerOutboundQueueSize(numMessages+8), WithValidateQueueSize(numMessages+8))
3700+
3701+
var topics []*Topic
3702+
var msgs []*Subscription
3703+
for _, ps := range psubs {
3704+
topic, err := ps.Join("foobar")
3705+
if err != nil {
3706+
t.Fatal(err)
3707+
}
3708+
topics = append(topics, topic)
36973709

3698-
var topics []*Topic
3699-
var msgs []*Subscription
3700-
for _, ps := range psubs {
3701-
topic, err := ps.Join("foobar")
3702-
if err != nil {
3703-
t.Fatal(err)
3704-
}
3705-
topics = append(topics, topic)
3710+
subch, err := topic.Subscribe(WithBufferSize(numMessages + 8))
3711+
if err != nil {
3712+
t.Fatal(err)
3713+
}
37063714

3707-
subch, err := topic.Subscribe(WithBufferSize(numMessages + 8))
3708-
if err != nil {
3709-
t.Fatal(err)
3710-
}
3715+
msgs = append(msgs, subch)
3716+
}
37113717

3712-
msgs = append(msgs, subch)
3713-
}
3718+
sparseConnect(t, hosts)
37143719

3715-
sparseConnect(t, hosts)
3720+
// wait for heartbeats to build mesh
3721+
time.Sleep(time.Second * 2)
37163722

3717-
// wait for heartbeats to build mesh
3718-
time.Sleep(time.Second * 2)
3719-
3720-
var batch MessageBatch
3721-
for i := 0; i < numMessages; i++ {
3722-
msg := []byte(fmt.Sprintf("%d it's not a floooooood %d", i, i))
3723-
err := topics[0].AddToBatch(ctx, &batch, msg)
3724-
if err != nil {
3725-
t.Fatal(err)
3726-
}
3727-
}
3728-
err := psubs[0].PublishBatch(&batch)
3729-
if err != nil {
3730-
t.Fatal(err)
3731-
}
3732-
3733-
for range numMessages {
3734-
for _, sub := range msgs {
3735-
got, err := sub.Next(ctx)
3723+
var batch MessageBatch
3724+
var wg sync.WaitGroup
3725+
for i := 0; i < numMessages; i++ {
3726+
msg := []byte(fmt.Sprintf("%d it's not a floooooood %d", i, i))
3727+
if concurrentAdd {
3728+
wg.Add(1)
3729+
go func() {
3730+
defer wg.Done()
3731+
err := topics[0].AddToBatch(ctx, &batch, msg)
3732+
if err != nil {
3733+
t.Log(err)
3734+
t.Fail()
3735+
}
3736+
}()
3737+
} else {
3738+
err := topics[0].AddToBatch(ctx, &batch, msg)
3739+
if err != nil {
3740+
t.Fatal(err)
3741+
}
3742+
}
3743+
}
3744+
wg.Wait()
3745+
err := psubs[0].PublishBatch(&batch)
37363746
if err != nil {
37373747
t.Fatal(err)
37383748
}
3739-
id := msgIDFn(got.Message)
3740-
expected := []byte(fmt.Sprintf("%s it's not a floooooood %s", id, id))
3741-
if !bytes.Equal(expected, got.Data) {
3742-
t.Fatal("got wrong message!")
3749+
3750+
for range numMessages {
3751+
for _, sub := range msgs {
3752+
got, err := sub.Next(ctx)
3753+
if err != nil {
3754+
t.Fatal(err)
3755+
}
3756+
id := msgIDFn(got.Message)
3757+
expected := []byte(fmt.Sprintf("%s it's not a floooooood %s", id, id))
3758+
if !bytes.Equal(expected, got.Data) {
3759+
t.Fatal("got wrong message!")
3760+
}
3761+
}
37433762
}
3744-
}
3763+
})
37453764
}
37463765
}
37473766

messagebatch.go

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package pubsub
22

33
import (
44
"iter"
5+
"sync"
56

67
"github.com/libp2p/go-libp2p/core/peer"
78
)
@@ -10,9 +11,24 @@ import (
1011
// once. This allows the Scheduler to define an order for outgoing RPCs.
1112
// This helps bandwidth constrained peers.
1213
type MessageBatch struct {
14+
mu sync.Mutex
1315
messages []*Message
1416
}
1517

18+
func (mb *MessageBatch) add(msg *Message) {
19+
mb.mu.Lock()
20+
defer mb.mu.Unlock()
21+
mb.messages = append(mb.messages, msg)
22+
}
23+
24+
func (mb *MessageBatch) take() []*Message {
25+
mb.mu.Lock()
26+
defer mb.mu.Unlock()
27+
messages := mb.messages
28+
mb.messages = nil
29+
return messages
30+
}
31+
1632
type messageBatchAndPublishOptions struct {
1733
messages []*Message
1834
opts *BatchPublishOptions

pubsub.go

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1600,12 +1600,10 @@ func (p *PubSub) PublishBatch(batch *MessageBatch, opts ...BatchPubOpt) error {
16001600
setDefaultBatchPublishOptions(publishOptions)
16011601

16021602
p.sendMessageBatch <- messageBatchAndPublishOptions{
1603-
messages: batch.messages,
1603+
messages: batch.take(),
16041604
opts: publishOptions,
16051605
}
16061606

1607-
// Clear the batch's messages in case a user reuses the same batch object
1608-
batch.messages = nil
16091607
return nil
16101608
}
16111609

topic.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,7 @@ func (t *Topic) AddToBatch(ctx context.Context, batch *MessageBatch, data []byte
257257
}
258258
return err
259259
}
260-
batch.messages = append(batch.messages, msg)
260+
batch.add(msg)
261261
return nil
262262
}
263263

0 commit comments

Comments
 (0)