Skip to content

Commit 494b60b

Browse files
authored
feat: implement TLS verification for incoming and outgoing pool connections on both sides
1 parent 24588b4 commit 494b60b

File tree

3 files changed

+119
-99
lines changed

3 files changed

+119
-99
lines changed

internal/client.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,13 +156,15 @@ func (c *Client) commonStart() error {
156156
})
157157
go c.tunnelPool.ClientManager()
158158

159+
// 判断数据流向
159160
if c.dataFlow == "+" {
160-
// 初始化目标监听器
161161
if err := c.initTargetListener(); err != nil {
162162
return fmt.Errorf("commonStart: initTargetListener failed: %w", err)
163163
}
164164
go c.commonLoop()
165165
}
166+
167+
// 启动共用控制
166168
if err := c.commonControl(); err != nil {
167169
return fmt.Errorf("commonStart: commonControl failed: %w", err)
168170
}

internal/common.go

Lines changed: 110 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ type Common struct {
3131
mu sync.Mutex // 互斥锁
3232
logger *logs.Logger // 日志记录器
3333
tlsCode string // TLS模式代码
34+
tlsConfig *tls.Config // TLS配置
3435
runMode string // 运行模式
3536
dataFlow string // 数据流向
3637
tunnelKey string // 隧道密钥
@@ -714,6 +715,14 @@ func (c *Common) healthCheck() error {
714715
ticker := time.NewTicker(reportInterval)
715716
defer ticker.Stop()
716717

718+
go func() {
719+
select {
720+
case <-c.ctx.Done():
721+
case <-ticker.C:
722+
c.incomingVerify()
723+
}
724+
}()
725+
717726
for c.ctx.Err() == nil {
718727
// 尝试获取锁
719728
if !c.mu.TryLock() {
@@ -763,6 +772,57 @@ func (c *Common) healthCheck() error {
763772
return fmt.Errorf("healthCheck: context error: %w", c.ctx.Err())
764773
}
765774

775+
// incomingVerify 入口连接验证
776+
func (c *Common) incomingVerify() {
777+
for c.ctx.Err() == nil {
778+
if c.tunnelPool.Ready() {
779+
break
780+
}
781+
select {
782+
case <-c.ctx.Done():
783+
continue
784+
case <-time.After(50 * time.Millisecond):
785+
}
786+
}
787+
788+
if c.tlsConfig == nil || len(c.tlsConfig.Certificates) == 0 {
789+
return
790+
}
791+
792+
cert := c.tlsConfig.Certificates[0]
793+
if len(cert.Certificate) == 0 {
794+
return
795+
}
796+
797+
// 打印证书指纹
798+
c.logger.Info("TLS cert verified: %v", c.formatCertFingerprint(cert.Certificate[0]))
799+
800+
id, testConn, err := c.tunnelPool.IncomingGet(poolGetTimeout)
801+
if err != nil {
802+
return
803+
}
804+
defer testConn.Close()
805+
806+
// 构建并发送验证信号
807+
verifyURL := &url.URL{
808+
Scheme: "np",
809+
Host: c.tunnelTCPConn.RemoteAddr().String(),
810+
Path: url.PathEscape(id),
811+
Fragment: "v", // TLS验证
812+
}
813+
814+
if c.ctx.Err() == nil && c.tunnelTCPConn != nil {
815+
c.mu.Lock()
816+
_, err = c.tunnelTCPConn.Write(c.encode([]byte(verifyURL.String())))
817+
c.mu.Unlock()
818+
if err != nil {
819+
return
820+
}
821+
}
822+
823+
c.logger.Debug("TLS verify signal: cid %v -> %v", id, c.tunnelTCPConn.RemoteAddr())
824+
}
825+
766826
// commonLoop 共用处理循环
767827
func (c *Common) commonLoop() {
768828
for c.ctx.Err() == nil {
@@ -1043,48 +1103,8 @@ func (c *Common) commonOnce() error {
10431103
// 处理信号
10441104
switch signalURL.Fragment {
10451105
case "v": // 验证
1046-
for c.ctx.Err() == nil {
1047-
if c.tunnelPool.Ready() {
1048-
break
1049-
}
1050-
select {
1051-
case <-c.ctx.Done():
1052-
continue
1053-
case <-time.After(50 * time.Millisecond):
1054-
}
1055-
}
1056-
id := strings.TrimPrefix(signalURL.Path, "/")
1057-
if unescapedID, err := url.PathUnescape(id); err != nil {
1058-
c.logger.Error("commonOnce: unescape id failed: %v", err)
1059-
continue
1060-
} else {
1061-
id = unescapedID
1062-
}
1063-
c.logger.Debug("TLS verify signal: cid %v <- %v", id, c.tunnelTCPConn.RemoteAddr())
1064-
1065-
testConn, err := c.tunnelPool.OutgoingGet(id, poolGetTimeout)
1066-
if err != nil {
1067-
c.logger.Error("commonOnce: request timeout: %v", err)
1068-
c.tunnelPool.AddError()
1069-
continue
1070-
}
1071-
1072-
if testConn != nil {
1073-
tlsConn, ok := testConn.(*tls.Conn)
1074-
if !ok {
1075-
c.logger.Error("commonOnce: connection is not TLS")
1076-
continue
1077-
}
1078-
1079-
state := tlsConn.ConnectionState()
1080-
if len(state.PeerCertificates) == 0 {
1081-
c.logger.Error("commonOnce: no peer certificates found")
1082-
continue
1083-
}
1084-
1085-
// 打印证书指纹
1086-
c.logger.Info("TLS cert verified: %v", c.formatCertFingerprint(state.PeerCertificates[0].Raw))
1087-
testConn.Close()
1106+
if c.tlsCode == "1" || c.tlsCode == "2" {
1107+
go c.outgoingVerify(signalURL)
10881108
}
10891109
case "1": // TCP
10901110
if c.disableTCP != "1" {
@@ -1132,6 +1152,54 @@ func (c *Common) commonOnce() error {
11321152
return fmt.Errorf("commonOnce: context error: %w", c.ctx.Err())
11331153
}
11341154

1155+
// outgoingVerify 出口连接验证
1156+
func (c *Common) outgoingVerify(signalURL *url.URL) {
1157+
for c.ctx.Err() == nil {
1158+
if c.tunnelPool.Ready() {
1159+
break
1160+
}
1161+
select {
1162+
case <-c.ctx.Done():
1163+
continue
1164+
case <-time.After(50 * time.Millisecond):
1165+
}
1166+
}
1167+
1168+
id := strings.TrimPrefix(signalURL.Path, "/")
1169+
if unescapedID, err := url.PathUnescape(id); err != nil {
1170+
c.logger.Error("outgoingVerify: unescape id failed: %v", err)
1171+
return
1172+
} else {
1173+
id = unescapedID
1174+
}
1175+
c.logger.Debug("TLS verify signal: cid %v <- %v", id, c.tunnelTCPConn.RemoteAddr())
1176+
1177+
testConn, err := c.tunnelPool.OutgoingGet(id, poolGetTimeout)
1178+
if err != nil {
1179+
c.logger.Error("outgoingVerify: request timeout: %v", err)
1180+
c.tunnelPool.AddError()
1181+
return
1182+
}
1183+
defer testConn.Close()
1184+
1185+
if testConn != nil {
1186+
tlsConn, ok := testConn.(*tls.Conn)
1187+
if !ok {
1188+
c.logger.Error("outgoingVerify: connection is not TLS")
1189+
return
1190+
}
1191+
1192+
state := tlsConn.ConnectionState()
1193+
if len(state.PeerCertificates) == 0 {
1194+
c.logger.Error("outgoingVerify: no peer certificates found")
1195+
return
1196+
}
1197+
1198+
// 打印证书指纹
1199+
c.logger.Info("TLS cert verified: %v", c.formatCertFingerprint(state.PeerCertificates[0].Raw))
1200+
}
1201+
}
1202+
11351203
// commonTCPOnce 共用处理单个TCP请求
11361204
func (c *Common) commonTCPOnce(signalURL *url.URL) {
11371205
id := strings.TrimPrefix(signalURL.Path, "/")

internal/server.go

Lines changed: 6 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -23,16 +23,16 @@ import (
2323

2424
// Server 实现服务端模式功能
2525
type Server struct {
26-
Common // 继承共享功能
27-
tlsConfig *tls.Config // TLS配置
28-
clientIP string // 客户端IP
26+
Common // 继承共享功能
27+
clientIP string // 客户端IP
2928
}
3029

3130
// NewServer 创建新的服务端实例
3231
func NewServer(parsedURL *url.URL, tlsCode string, tlsConfig *tls.Config, logger *logs.Logger) (*Server, error) {
3332
server := &Server{
3433
Common: Common{
3534
tlsCode: tlsCode,
35+
tlsConfig: tlsConfig,
3636
logger: logger,
3737
signalChan: make(chan string, semaphoreLimit),
3838
tcpBufferPool: &sync.Pool{
@@ -51,7 +51,6 @@ func NewServer(parsedURL *url.URL, tlsCode string, tlsConfig *tls.Config, logger
5151
pingURL: &url.URL{Scheme: "np", Fragment: "i"},
5252
pongURL: &url.URL{Scheme: "np", Fragment: "o"},
5353
},
54-
tlsConfig: tlsConfig,
5554
}
5655
if err := server.initConfig(parsedURL); err != nil {
5756
return nil, fmt.Errorf("newServer: initConfig failed: %w", err)
@@ -151,61 +150,12 @@ func (s *Server) start() error {
151150
reportInterval)
152151
go s.tunnelPool.ServerManager()
153152

154-
// 验证TLS证书指纹
155-
if s.tlsCode == "1" || s.tlsCode == "2" {
156-
if s.tlsConfig == nil || len(s.tlsConfig.Certificates) == 0 {
157-
return fmt.Errorf("start: tlsConfig missing certificates for TLS verification")
158-
}
159-
160-
cert := s.tlsConfig.Certificates[0]
161-
if len(cert.Certificate) == 0 {
162-
return fmt.Errorf("start: no certificates found in tlsConfig for TLS verification")
163-
}
164-
165-
// 打印证书指纹
166-
s.logger.Info("TLS cert verified: %v", s.formatCertFingerprint(cert.Certificate[0]))
167-
168-
for s.ctx.Err() == nil {
169-
if s.tunnelPool.Ready() {
170-
break
171-
}
172-
select {
173-
case <-s.ctx.Done():
174-
return fmt.Errorf("start: context error: %w", s.ctx.Err())
175-
case <-time.After(50 * time.Millisecond):
176-
}
177-
}
178-
179-
id, testConn, err := s.tunnelPool.IncomingGet(poolGetTimeout)
180-
if err != nil {
181-
return fmt.Errorf("start: failed to get test connection from pool: %w", err)
182-
}
183-
184-
// 构建并发送验证信号
185-
verifyURL := &url.URL{
186-
Scheme: "np",
187-
Host: s.tunnelTCPConn.RemoteAddr().String(),
188-
Path: url.PathEscape(id),
189-
Fragment: "v", // TLS验证
190-
}
191-
192-
if s.ctx.Err() == nil && s.tunnelTCPConn != nil {
193-
s.mu.Lock()
194-
_, err = s.tunnelTCPConn.Write(s.encode([]byte(verifyURL.String())))
195-
s.mu.Unlock()
196-
if err != nil {
197-
testConn.Close()
198-
return fmt.Errorf("start: write TLS verify signal failed: %w", err)
199-
}
200-
}
201-
202-
s.logger.Debug("TLS verify signal: cid %v -> %v", id, s.tunnelTCPConn.RemoteAddr())
203-
testConn.Close()
204-
}
205-
153+
// 判断数据流向
206154
if s.dataFlow == "-" {
207155
go s.commonLoop()
208156
}
157+
158+
// 启动共用控制
209159
if err := s.commonControl(); err != nil {
210160
return fmt.Errorf("start: commonControl failed: %w", err)
211161
}

0 commit comments

Comments
 (0)