Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions core/network/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,11 @@ type Conn interface {
//
// You very likely do not need to use this method.
As(target any) bool

// Context returns a context that is cancelled when the connection is closed.
// This can be used to clean up resources associated with the connection
// and to signal early cancellation of work that depends on the connection.
Context() context.Context
}

// ConnectionState holds information about the connection.
Expand Down
23 changes: 23 additions & 0 deletions p2p/net/mock/mock_conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"strconv"
"sync"
"sync/atomic"
"time"

ic "github.com/libp2p/go-libp2p/core/crypto"
"github.com/libp2p/go-libp2p/core/network"
Expand Down Expand Up @@ -193,3 +194,25 @@ func (c *conn) Scope() network.ConnScope {
// CloseWithError closes the mock connection. The error code is ignored:
// mocknet has no wire protocol over which to convey it to the remote side,
// so this simply delegates to Close.
func (c *conn) CloseWithError(_ network.ConnErrorCode) error {
return c.Close()
}

// Context returns a context that is cancelled when the connection is closed.
//
// Mock-only implementation: each call allocates a fresh context plus a
// goroutine that polls IsClosed every 10ms, so successive calls return
// distinct context instances (unlike the swarm implementation, which
// returns one per-connection context). This is a simplification for
// testing; avoid calling it in a tight loop.
func (c *conn) Context() context.Context {
	ctx, cancel := context.WithCancel(context.Background())
	go func() {
		ticker := time.NewTicker(10 * time.Millisecond)
		defer ticker.Stop()
		// Range over the ticker directly instead of a single-case
		// select loop (which staticcheck flags as redundant, S1000).
		for range ticker.C {
			if c.IsClosed() {
				cancel()
				return
			}
		}
	}()
	return ctx
}
11 changes: 7 additions & 4 deletions p2p/net/swarm/swarm.go
Original file line number Diff line number Diff line change
Expand Up @@ -374,11 +374,14 @@ func (s *Swarm) addConn(tc transport.CapableConn, dir network.Direction) (*Conn,
isLimited := stat.Limited

// Wrap and register the connection.
ctx, cancel := context.WithCancel(context.Background())
c := &Conn{
conn: tc,
swarm: s,
stat: stat,
id: s.nextConnID.Add(1),
conn: tc,
swarm: s,
stat: stat,
id: s.nextConnID.Add(1),
ctx: ctx,
cancel: cancel,
}

// we ONLY check upgraded connections here so we can send them a Disconnect message.
Expand Down
12 changes: 12 additions & 0 deletions p2p/net/swarm/swarm_conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@ type Conn struct {
}

stat network.ConnStats

// Context and cancel function for connection lifecycle management
ctx context.Context
cancel context.CancelFunc
}

var _ network.Conn = &Conn{}
Expand Down Expand Up @@ -78,6 +82,9 @@ func (c *Conn) CloseWithError(errCode network.ConnErrorCode) error {
func (c *Conn) doClose(errCode network.ConnErrorCode) {
c.swarm.removeConn(c)

// Cancel the context to signal that the connection is closed
c.cancel()

// Prevent new streams from opening.
c.streams.Lock()
streams := c.streams.m
Expand Down Expand Up @@ -297,3 +304,8 @@ func (c *Conn) GetStreams() []network.Stream {
// Scope returns the resource-manager scope of the underlying transport
// connection; the swarm Conn adds no scoping of its own.
func (c *Conn) Scope() network.ConnScope {
return c.conn.Scope()
}

// Context returns a context that is cancelled when the connection is
// closed. The cancellation is triggered from doClose, so it fires for
// any close path (local Close, CloseWithError, or remote close). Every
// call returns the same per-connection context instance, created in
// Swarm.addConn.
func (c *Conn) Context() context.Context {
return c.ctx
}
248 changes: 248 additions & 0 deletions p2p/net/swarm/swarm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -626,3 +626,251 @@ func TestAddCertHashes(t *testing.T) {
}
}
}

func TestConnContext(t *testing.T) {
// Test that the context is cancelled when the connection is closed
t.Run("ContextCancelledOnClose", func(t *testing.T) {
s1 := GenSwarm(t, OptDisableReuseport)
s2 := GenSwarm(t, OptDisableReuseport)
defer s1.Close()
defer s2.Close()

// Connect the swarms
s1.Peerstore().AddAddrs(s2.LocalPeer(), s2.ListenAddresses(), peerstore.TempAddrTTL)
_, err := s1.DialPeer(context.Background(), s2.LocalPeer())
require.NoError(t, err)

// Get the connection
conns := s1.ConnsToPeer(s2.LocalPeer())
require.Len(t, conns, 1)
conn := conns[0]

// Get the context
ctx := conn.Context()
require.NotNil(t, ctx)

// Context should not be cancelled initially
select {
case <-ctx.Done():
t.Fatal("context should not be cancelled initially")
default:
}

// Close the connection
err = conn.Close()
require.NoError(t, err)

// Context should be cancelled now
select {
case <-ctx.Done():
// Expected
case <-time.After(time.Second):
t.Fatal("context should be cancelled after connection close")
}

// Verify the context error
require.Error(t, ctx.Err())
require.Equal(t, context.Canceled, ctx.Err())
})

// Test that the context is cancelled when the connection is closed with error
t.Run("ContextCancelledOnCloseWithError", func(t *testing.T) {
s1 := GenSwarm(t, OptDisableReuseport)
s2 := GenSwarm(t, OptDisableReuseport)
defer s1.Close()
defer s2.Close()

// Connect the swarms
s1.Peerstore().AddAddrs(s2.LocalPeer(), s2.ListenAddresses(), peerstore.TempAddrTTL)
_, err := s1.DialPeer(context.Background(), s2.LocalPeer())
require.NoError(t, err)

// Get the connection
conns := s1.ConnsToPeer(s2.LocalPeer())
require.Len(t, conns, 1)
conn := conns[0]

// Get the context
ctx := conn.Context()
require.NotNil(t, ctx)

// Context should not be cancelled initially
select {
case <-ctx.Done():
t.Fatal("context should not be cancelled initially")
default:
}

// Close the connection with error
err = conn.CloseWithError(network.ConnShutdown)
require.NoError(t, err)

// Context should be cancelled now
select {
case <-ctx.Done():
// Expected
case <-time.After(time.Second):
t.Fatal("context should be cancelled after connection close with error")
}

// Verify the context error
require.Error(t, ctx.Err())
require.Equal(t, context.Canceled, ctx.Err())
})

// Test that the context can be used with context.AfterFunc
t.Run("ContextAfterFunc", func(t *testing.T) {
s1 := GenSwarm(t, OptDisableReuseport)
s2 := GenSwarm(t, OptDisableReuseport)
defer s1.Close()
defer s2.Close()

// Connect the swarms
s1.Peerstore().AddAddrs(s2.LocalPeer(), s2.ListenAddresses(), peerstore.TempAddrTTL)
_, err := s1.DialPeer(context.Background(), s2.LocalPeer())
require.NoError(t, err)

// Get the connection
conns := s1.ConnsToPeer(s2.LocalPeer())
require.Len(t, conns, 1)
conn := conns[0]

// Get the context
ctx := conn.Context()
require.NotNil(t, ctx)

// Use context.AfterFunc to clean up resources
var cleanupCalled bool
var cleanupMutex sync.Mutex
context.AfterFunc(ctx, func() {
cleanupMutex.Lock()
cleanupCalled = true
cleanupMutex.Unlock()
})

// Close the connection
err = conn.Close()
require.NoError(t, err)

// Wait for the cleanup function to be called
require.Eventually(t, func() bool {
cleanupMutex.Lock()
defer cleanupMutex.Unlock()
return cleanupCalled
}, time.Second, 10*time.Millisecond, "cleanup function should be called")

// Verify the context error
require.Error(t, ctx.Err())
require.Equal(t, context.Canceled, ctx.Err())
})

// Test that multiple contexts from the same connection are all cancelled
t.Run("MultipleContextsCancelled", func(t *testing.T) {
s1 := GenSwarm(t, OptDisableReuseport)
s2 := GenSwarm(t, OptDisableReuseport)
defer s1.Close()
defer s2.Close()

// Connect the swarms
s1.Peerstore().AddAddrs(s2.LocalPeer(), s2.ListenAddresses(), peerstore.TempAddrTTL)
_, err := s1.DialPeer(context.Background(), s2.LocalPeer())
require.NoError(t, err)

// Get the connection
conns := s1.ConnsToPeer(s2.LocalPeer())
require.Len(t, conns, 1)
conn := conns[0]

// Get multiple contexts
ctx1 := conn.Context()
ctx2 := conn.Context()
require.NotNil(t, ctx1)
require.NotNil(t, ctx2)

// Both contexts should be the same instance
require.Equal(t, ctx1, ctx2)

// Contexts should not be cancelled initially
select {
case <-ctx1.Done():
t.Fatal("context1 should not be cancelled initially")
case <-ctx2.Done():
t.Fatal("context2 should not be cancelled initially")
default:
}

// Close the connection
err = conn.Close()
require.NoError(t, err)

// Both contexts should be cancelled now
select {
case <-ctx1.Done():
// Expected
case <-time.After(time.Second):
t.Fatal("context1 should be cancelled after connection close")
}

select {
case <-ctx2.Done():
// Expected
case <-time.After(time.Second):
t.Fatal("context2 should be cancelled after connection close")
}

// Verify both context errors
require.Error(t, ctx1.Err())
require.Equal(t, context.Canceled, ctx1.Err())
require.Error(t, ctx2.Err())
require.Equal(t, context.Canceled, ctx2.Err())
})

// Test that the context is cancelled when the remote peer closes the connection
t.Run("ContextCancelledOnRemoteClose", func(t *testing.T) {
s1 := GenSwarm(t, OptDisableReuseport)
s2 := GenSwarm(t, OptDisableReuseport)
defer s1.Close()
defer s2.Close()

// Connect the swarms
s1.Peerstore().AddAddrs(s2.LocalPeer(), s2.ListenAddresses(), peerstore.TempAddrTTL)
_, err := s1.DialPeer(context.Background(), s2.LocalPeer())
require.NoError(t, err)

// Get the connection
conns := s1.ConnsToPeer(s2.LocalPeer())
require.Len(t, conns, 1)
conn := conns[0]

// Get the context
ctx := conn.Context()
require.NotNil(t, ctx)

// Context should not be cancelled initially
select {
case <-ctx.Done():
t.Fatal("context should not be cancelled initially")
default:
}

// Close the remote swarm (simulating remote peer closing)
s2.Close()

// Wait for the connection to be closed
require.Eventually(t, func() bool {
return conn.IsClosed()
}, time.Second, 10*time.Millisecond, "connection should be closed")

// Context should be cancelled now
select {
case <-ctx.Done():
// Expected
case <-time.After(time.Second):
t.Fatal("context should be cancelled after remote close")
}

// Verify the context error
require.Error(t, ctx.Err())
require.Equal(t, context.Canceled, ctx.Err())
})
}