Skip to content

Commit 892f6e9

Browse files
improve the packet API (#30)
1 parent 25cbbbc commit 892f6e9

File tree

5 files changed

+44
-52
lines changed

5 files changed

+44
-52
lines changed

conn.go

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,7 @@ func (c *Conn) writeToStream() error {
243243
}
244244
}
245245

246-
func (c *Conn) Read(b []byte) (n int, err error) {
246+
func (c *Conn) ReadPacket(b []byte) (n int, err error) {
247247
start:
248248
data, err := c.str.ReceiveDatagram(context.Background())
249249
if err != nil {
@@ -339,34 +339,31 @@ func (c *Conn) handleIncomingProxiedPacket(data []byte) error {
339339
return nil
340340
}
341341

342-
type PacketTooBigError struct {
343-
ICMPPacket []byte
344-
}
345-
346-
func (e *PacketTooBigError) Error() string { return "connect-ip: packet too big" }
347-
348-
func (c *Conn) Write(b []byte) (n int, err error) {
342+
// WritePacket writes an IP packet to the stream.
343+
// If sending the packet fails, it might return an ICMP packet.
344+
// It is the caller's responsibility to send the ICMP packet to the sender.
345+
func (c *Conn) WritePacket(b []byte) (icmp []byte, err error) {
349346
data, err := c.composeDatagram(b)
350347
if err != nil {
351348
log.Printf("dropping proxied packet (%d bytes) that can't be proxied: %s", len(b), err)
352-
return 0, nil
349+
return nil, nil
353350
}
354351
if err := c.str.SendDatagram(data); err != nil {
355352
if errors.Is(err, &quic.DatagramTooLargeError{}) {
356353
icmpPacket, err := composeICMPTooLargePacket(b, minMTU)
357354
if err != nil {
358355
log.Printf("failed to compose ICMP too large packet: %s", err)
359356
}
360-
return 0, &PacketTooBigError{ICMPPacket: icmpPacket}
357+
return icmpPacket, nil
361358
}
362359
select {
363360
case <-c.closeChan:
364-
return 0, c.closeErr
361+
return nil, c.closeErr
365362
default:
366-
return 0, err
363+
return nil, err
367364
}
368365
}
369-
return len(b), nil
366+
return nil, nil
370367
}
371368

372369
func (c *Conn) composeDatagram(b []byte) ([]byte, error) {

conn_test.go

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -275,8 +275,7 @@ func TestSendLargeDatagrams(t *testing.T) {
275275
Protocol: 17,
276276
}).Marshal()
277277
require.NoError(t, err)
278-
_, err = conn.Write(data)
279-
var pktTooBigErr *PacketTooBigError
280-
require.ErrorAs(t, err, &pktTooBigErr)
281-
require.NotNil(t, pktTooBigErr.ICMPPacket)
278+
icmp, err := conn.WritePacket(data)
279+
require.NoError(t, err)
280+
require.NotNil(t, icmp)
282281
}

integration/client/client.go

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,7 @@ func proxy(ipconn *connectip.Conn, dev *water.Interface) error {
204204
go func() {
205205
for {
206206
b := make([]byte, 1500)
207-
n, err := ipconn.Read(b)
207+
n, err := ipconn.ReadPacket(b)
208208
if err != nil {
209209
errChan <- fmt.Errorf("failed to read from connection: %w", err)
210210
return
@@ -226,19 +226,17 @@ func proxy(ipconn *connectip.Conn, dev *water.Interface) error {
226226
return
227227
}
228228
log.Printf("read %d bytes from TUN", n)
229-
if _, err := ipconn.Write(b[:n]); err != nil {
230-
var tooBigErr *connectip.PacketTooBigError
231-
if errors.As(err, &tooBigErr) {
232-
if len(tooBigErr.ICMPPacket) > 0 {
233-
if _, err := dev.Write(tooBigErr.ICMPPacket); err != nil {
234-
log.Printf("faield to write ICMP packet to %s: %v", dev.Name(), err)
235-
}
236-
}
237-
continue
238-
}
229+
icmp, err := ipconn.WritePacket(b[:n])
230+
if err != nil {
239231
errChan <- fmt.Errorf("failed to write to connection: %w", err)
240232
return
241233
}
234+
if len(icmp) > 0 {
235+
log.Printf("sending ICMP packet on %s", dev.Name())
236+
if _, err := dev.Write(icmp); err != nil {
237+
log.Printf("failed to write ICMP packet: %v", err)
238+
}
239+
}
242240
}
243241
}()
244242

integration/proxy/proxy.go

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ package main
55
import (
66
"context"
77
"crypto/tls"
8-
"errors"
98
"fmt"
109
"log"
1110
"net"
@@ -217,7 +216,7 @@ func handleConn(conn *connectip.Conn, addr netip.Addr, route netip.Prefix, ipPro
217216
go func() {
218217
for {
219218
b := make([]byte, 1500)
220-
n, err := conn.Read(b)
219+
n, err := conn.ReadPacket(b)
221220
if err != nil {
222221
errChan <- fmt.Errorf("failed to read from connection: %w", err)
223222
return
@@ -239,21 +238,16 @@ func handleConn(conn *connectip.Conn, addr netip.Addr, route netip.Prefix, ipPro
239238
return
240239
}
241240
log.Printf("read %d bytes from %s", n, ifaceName)
242-
if _, err := conn.Write(b[:n]); err != nil {
243-
var tooBigErr *connectip.PacketTooBigError
244-
if errors.As(err, &tooBigErr) {
245-
if len(tooBigErr.ICMPPacket) > 0 {
246-
if err := sendOnSocket(serverSocketSend, tooBigErr.ICMPPacket); err != nil {
247-
errChan <- fmt.Errorf("writing to server socket: %w", err)
248-
return
249-
}
250-
}
251-
continue
252-
}
253-
241+
icmp, err := conn.WritePacket(b[:n])
242+
if err != nil {
254243
errChan <- fmt.Errorf("failed to write to connection: %w", err)
255244
return
256245
}
246+
if len(icmp) > 0 {
247+
if err := sendOnSocket(serverSocketSend, icmp); err != nil {
248+
log.Printf("failed to send ICMP packet: %v", err)
249+
}
250+
}
257251
}
258252
}()
259253

proxy_test.go

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -159,8 +159,9 @@ func TestTTLs(t *testing.T) {
159159
}
160160
packetTTL1, err := hdrTTL1.Marshal()
161161
require.NoError(t, err)
162-
_, err = client.Write(packetTTL1)
162+
icmp, err := client.WritePacket(packetTTL1)
163163
require.NoError(t, err)
164+
require.Empty(t, icmp)
164165

165166
// now send a packet with TTL 42
166167
hdr := &ipv4.Header{
@@ -171,11 +172,12 @@ func TestTTLs(t *testing.T) {
171172
}
172173
packet, err := hdr.Marshal()
173174
require.NoError(t, err)
174-
_, err = client.Write(packet)
175+
icmp, err = client.WritePacket(packet)
175176
require.NoError(t, err)
177+
require.Empty(t, icmp)
176178

177179
receivedPacket := make([]byte, 1500)
178-
n, err := server.Read(receivedPacket)
180+
n, err := server.ReadPacket(receivedPacket)
179181
require.NoError(t, err)
180182
receivedPacket = receivedPacket[:n]
181183

@@ -204,8 +206,9 @@ func TestTTLs(t *testing.T) {
204206
0x20, 0x01, 0x0d, 0xb8, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, // Source IP
205207
0x20, 0x01, 0x48, 0x60, 0x48, 0x60, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x88, 0x88, // Destination IP
206208
}
207-
_, err := client.Write(packetHopLimit1)
209+
icmp, err := client.WritePacket(packetHopLimit1)
208210
require.NoError(t, err)
211+
require.Empty(t, icmp)
209212

210213
// now send a packet with Hop Limit 42
211214
packet := []byte{
@@ -215,11 +218,12 @@ func TestTTLs(t *testing.T) {
215218
0x20, 0x01, 0x0d, 0xb8, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, // Source IP
216219
0x20, 0x01, 0x48, 0x60, 0x48, 0x60, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x88, 0x88, // Destination IP
217220
}
218-
_, err = client.Write(packet)
221+
icmp, err = client.WritePacket(packet)
219222
require.NoError(t, err)
223+
require.Empty(t, icmp)
220224

221225
receivedPacket := make([]byte, 1500)
222-
n, err := server.Read(receivedPacket)
226+
n, err := server.ReadPacket(receivedPacket)
223227
require.NoError(t, err)
224228
receivedPacket = receivedPacket[:n]
225229

@@ -269,9 +273,9 @@ func TestClosing(t *testing.T) {
269273
}),
270274
net.ErrClosed,
271275
)
272-
_, err = client.Read([]byte{0})
276+
_, err = client.ReadPacket([]byte{0})
273277
require.ErrorIs(t, err, net.ErrClosed)
274-
_, err = client.Write(ipv6Packet)
278+
_, err = client.WritePacket(ipv6Packet)
275279
require.ErrorIs(t, err, net.ErrClosed)
276280

277281
select {
@@ -288,8 +292,8 @@ func TestClosing(t *testing.T) {
288292
t.Fatal("timeout")
289293
}
290294

291-
_, err = server.Read([]byte{0})
295+
_, err = server.ReadPacket([]byte{0})
292296
require.ErrorIs(t, err, net.ErrClosed)
293-
_, err = server.Write(ipv6Packet)
297+
_, err = server.WritePacket(ipv6Packet)
294298
require.ErrorIs(t, err, net.ErrClosed)
295299
}

0 commit comments

Comments
 (0)