Skip to content

Commit 60cd348

Browse files
authored
feat(connector): changeAccount endpoint (#7053)
* feat(connector)_: change shared address command fixes #7048 * feat(connector)_: add chainIDSwitched signal fixes #7048 * feat(connector)_: add factory methods for commands fixes #7048 * feat(connector)_: reduce coupling * remove NetworkManager dependency from Commands fixes #7048
1 parent 31c94c9 commit 60cd348

21 files changed

+414
-148
lines changed

services/connector/api.go

Lines changed: 30 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -5,80 +5,58 @@ import (
55
"errors"
66
"fmt"
77

8+
"github.com/status-im/status-go/services/connector/chainutils"
89
"github.com/status-im/status-go/services/connector/commands"
910
persistence "github.com/status-im/status-go/services/connector/database"
1011
)
1112

1213
var (
1314
ErrInvalidResponseFromForwardedRpc = errors.New("invalid response from forwarded RPC")
1415
ErrCannotOverrideClientIDForHttpConnection = errors.New("cannot override clientId for HTTP connection")
16+
ErrNotAllowedForUntrustedConnection = errors.New("cannot call from untrusted connection")
1517
ErrEmptyClientIDFromTrustedConnection = errors.New("trusted connection must provide a clientId")
1618
)
1719

1820
type API struct {
19-
s *Service
20-
r *CommandRegistry
21-
c *commands.ClientSideHandler
21+
s *Service
22+
r *CommandRegistry
23+
c *commands.ClientSideHandler
24+
changeAccountCommand *commands.ChangeAccountCommand
2225
}
2326

2427
func NewAPI(s *Service) *API {
2528
r := NewCommandRegistry()
2629
c := commands.NewClientSideHandler(s.db)
2730

2831
// Transactions and signing
29-
r.Register("eth_sendTransaction", &commands.SendTransactionCommand{
30-
EthClientGetter: s.ethClientGetter,
31-
FeeManager: s.feeManager,
32-
Db: s.db,
33-
ClientHandler: c,
34-
})
35-
r.Register("personal_sign", &commands.SignCommand{
36-
Db: s.db,
37-
ClientHandler: c,
38-
})
39-
r.Register("eth_signTypedData_v4", &commands.SignCommand{
40-
Db: s.db,
41-
ClientHandler: c,
42-
})
32+
r.Register("eth_sendTransaction", commands.NewSendTransactionCommand(s.db, s.ethClientGetter, s.feeManager, c))
33+
r.Register("personal_sign", commands.NewSignCommand(s.db, c))
34+
r.Register("eth_signTypedData_v4", commands.NewSignCommand(s.db, c))
4335

4436
// Accounts query and dapp permissions
4537
// NOTE: Some dApps expect same behavior for both eth_accounts and eth_requestAccounts
46-
accountsCommand := &commands.RequestAccountsCommand{
47-
ClientHandler: c,
48-
Db: s.db,
49-
}
38+
accountsCommand := commands.NewRequestAccountsCommand(s.db, c)
5039
r.Register("eth_accounts", accountsCommand)
5140
r.Register("eth_requestAccounts", accountsCommand)
5241

5342
// Active chain per dapp management
54-
r.Register("eth_chainId", &commands.ChainIDCommand{
55-
Db: s.db,
56-
NetworkManager: s.nm,
57-
})
58-
r.Register("net_version", &commands.NetVersionCommand{
59-
Db: s.db,
60-
NetworkManager: s.nm,
61-
})
62-
r.Register("wallet_switchEthereumChain", &commands.SwitchEthereumChainCommand{
63-
Db: s.db,
64-
NetworkManager: s.nm,
65-
})
43+
defaultChainIDGetter := chainutils.NewNetworkManagerAdapter(s.nm)
44+
r.Register("eth_chainId", commands.NewChainIDCommand(s.db, defaultChainIDGetter))
45+
r.Register("net_version", commands.NewNetVersionCommand(s.db, defaultChainIDGetter))
46+
r.Register("wallet_switchEthereumChain", commands.NewSwitchEthereumChainCommand(s.db, s.nm))
6647

6748
// Permissions
68-
r.Register("wallet_requestPermissions", &commands.RequestPermissionsCommand{
69-
Db: s.db,
70-
})
71-
r.Register("wallet_getPermissions", &commands.GetPermissionsCommand{
72-
Db: s.db,
73-
})
74-
r.Register("wallet_revokePermissions", &commands.RevokePermissionsCommand{
75-
Db: s.db,
76-
})
49+
r.Register("wallet_requestPermissions", commands.NewRequestPermissionsCommand(s.db))
50+
r.Register("wallet_getPermissions", commands.NewGetPermissionsCommand(s.db))
51+
r.Register("wallet_revokePermissions", commands.NewRevokePermissionsCommand(s.db))
52+
53+
changeAccountCommand := commands.NewChangeAccountCommand(s.db)
7754

7855
return &API{
79-
s: s,
80-
r: r,
81-
c: c,
56+
s: s,
57+
r: r,
58+
c: c,
59+
changeAccountCommand: changeAccountCommand,
8260
}
8361
}
8462

@@ -168,3 +146,10 @@ func (api *API) SignAccepted(args commands.SignAcceptedArgs) error {
168146
func (api *API) SignRejected(args commands.RejectedArgs) error {
169147
return api.c.SignRejected(args)
170148
}
149+
150+
func (api *API) ChangeAccount(ctx context.Context, args commands.ChangeAccountArgs) error {
151+
if IsUntrustedConnection(ctx) {
152+
return ErrNotAllowedForUntrustedConnection
153+
}
154+
return api.changeAccountCommand.Execute(args)
155+
}

services/connector/api_test.go

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,3 +96,67 @@ func TestCallRPC_TrustedConnectionWithClientID(t *testing.T) {
9696
require.NoError(t, err)
9797
require.NotNil(t, result)
9898
}
99+
100+
func TestChangeAccount_UntrustedConnection(t *testing.T) {
101+
state := setupTests(t)
102+
103+
// Test untrusted connection (HTTP)
104+
ctx := WithConnectionType(context.Background(), ConnectionTypeHTTP)
105+
106+
args := commands.ChangeAccountArgs{
107+
URL: "https://example.com",
108+
ClientID: "test-client",
109+
}
110+
111+
err := state.api.ChangeAccount(ctx, args)
112+
require.Error(t, err)
113+
require.Equal(t, ErrNotAllowedForUntrustedConnection, err)
114+
}
115+
116+
func TestCallRPC_MethodNotAllowed(t *testing.T) {
117+
state := setupTests(t)
118+
119+
ctx := WithConnectionType(context.Background(), ConnectionTypeHTTP)
120+
121+
// Test a method that's not in the allowed list
122+
request := `{
123+
"method": "eth_subscribe",
124+
"params": [],
125+
"url": "https://example.com",
126+
"name": "Example DApp",
127+
"iconUrl": "https://example.com/icon.png"
128+
}`
129+
130+
result, err := state.api.CallRPC(ctx, request)
131+
require.Error(t, err)
132+
require.Nil(t, result)
133+
require.Contains(t, err.Error(), "not allowed")
134+
}
135+
136+
func TestCallRPC_InvalidJSON(t *testing.T) {
137+
state := setupTests(t)
138+
139+
ctx := WithConnectionType(context.Background(), ConnectionTypeHTTP)
140+
141+
request := `invalid json`
142+
143+
result, err := state.api.CallRPC(ctx, request)
144+
require.Error(t, err)
145+
require.Equal(t, "", result)
146+
}
147+
148+
func TestRecallDAppPermission_Deprecated(t *testing.T) {
149+
state := setupTests(t)
150+
151+
err := state.api.RecallDAppPermission("https://example.com")
152+
// Error is expected when dApp doesn't exist
153+
require.Error(t, err)
154+
}
155+
156+
func TestGetPermittedDAppsList(t *testing.T) {
157+
state := setupTests(t)
158+
159+
dapps, err := state.api.GetPermittedDAppsList()
160+
require.NoError(t, err)
161+
require.Empty(t, dapps)
162+
}

services/connector/chainutils/interfaces.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,3 +18,7 @@ type FeeManager interface {
1818
type EthClientGetter interface {
1919
EthClient(chainID uint64) (ethclient.EthClientInterface, error)
2020
}
21+
22+
type DefaultChainIDGetter interface {
23+
GetDefaultChainID() (uint64, error)
24+
}

services/connector/chainutils/utils.go

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,21 @@ var (
1313
ErrUnsupportedNetwork = errors.New("unsupported network")
1414
)
1515

16+
// Implement DefaultChainIDGetter interface
17+
type NetworkManagerAdapter struct {
18+
networkManager *network.Manager
19+
}
20+
21+
func NewNetworkManagerAdapter(networkManager *network.Manager) *NetworkManagerAdapter {
22+
return &NetworkManagerAdapter{
23+
networkManager: networkManager,
24+
}
25+
}
26+
27+
func (a *NetworkManagerAdapter) GetDefaultChainID() (uint64, error) {
28+
return GetDefaultChainID(a.networkManager)
29+
}
30+
1631
// GetSupportedChainIDs retrieves the chain IDs from the provided NetworkManager.
1732
func GetSupportedChainIDs(networkManager *network.Manager) ([]uint64, error) {
1833
activeNetworks, err := networkManager.GetActiveNetworks()

services/connector/commands/accounts.go

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,13 @@ import (
1010
)
1111

1212
type AccountsCommand struct {
13-
Db *sql.DB
13+
db *sql.DB
14+
}
15+
16+
func NewAccountsCommand(db *sql.DB) *AccountsCommand {
17+
return &AccountsCommand{
18+
db: db,
19+
}
1420
}
1521

1622
func FormatAccountAddressToResponse(address types.Address) []string {
@@ -23,7 +29,7 @@ func (c *AccountsCommand) Execute(ctx context.Context, request RPCRequest) (inte
2329
return "", err
2430
}
2531

26-
dApp, err := persistence.SelectDApp(c.Db, request.URL, request.ClientID)
32+
dApp, err := persistence.SelectDApp(c.db, request.URL, request.ClientID)
2733
if err != nil {
2834
return "", err
2935
}

services/connector/commands/chain_id.go

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,21 @@ import (
44
"context"
55
"database/sql"
66

7-
"github.com/status-im/status-go/rpc/network"
87
"github.com/status-im/status-go/services/connector/chainutils"
98
persistence "github.com/status-im/status-go/services/connector/database"
109
walletCommon "github.com/status-im/status-go/services/wallet/common"
1110
)
1211

1312
type ChainIDCommand struct {
14-
NetworkManager *network.Manager
15-
Db *sql.DB
13+
defaultChainIDGetter chainutils.DefaultChainIDGetter
14+
db *sql.DB
15+
}
16+
17+
func NewChainIDCommand(db *sql.DB, defaultChainIDGetter chainutils.DefaultChainIDGetter) *ChainIDCommand {
18+
return &ChainIDCommand{
19+
db: db,
20+
defaultChainIDGetter: defaultChainIDGetter,
21+
}
1622
}
1723

1824
func (c *ChainIDCommand) Execute(ctx context.Context, request RPCRequest) (interface{}, error) {
@@ -21,14 +27,14 @@ func (c *ChainIDCommand) Execute(ctx context.Context, request RPCRequest) (inter
2127
return "", err
2228
}
2329

24-
dApp, err := persistence.SelectDApp(c.Db, request.URL, request.ClientID)
30+
dApp, err := persistence.SelectDApp(c.db, request.URL, request.ClientID)
2531
if err != nil {
2632
return "", err
2733
}
2834

2935
var chainId uint64
3036
if dApp == nil {
31-
chainId, err = chainutils.GetDefaultChainID(c.NetworkManager)
37+
chainId, err = c.defaultChainIDGetter.GetDefaultChainID()
3238
if err != nil {
3339
return "", err
3440
}
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
package commands
2+
3+
import (
4+
"database/sql"
5+
6+
persistence "github.com/status-im/status-go/services/connector/database"
7+
"github.com/status-im/status-go/signal"
8+
)
9+
10+
type ChangeAccountCommand struct {
11+
db *sql.DB
12+
}
13+
14+
func NewChangeAccountCommand(db *sql.DB) *ChangeAccountCommand {
15+
return &ChangeAccountCommand{
16+
db: db,
17+
}
18+
}
19+
20+
func (c *ChangeAccountCommand) Execute(args ChangeAccountArgs) error {
21+
err := args.Validate()
22+
if err != nil {
23+
return err
24+
}
25+
26+
dApp, err := persistence.SelectDApp(c.db, args.URL, args.ClientID)
27+
if err != nil {
28+
return err
29+
}
30+
31+
if dApp == nil {
32+
return nil
33+
}
34+
35+
dApp.SharedAccount = args.Account
36+
37+
err = persistence.UpsertDApp(c.db, dApp)
38+
if err != nil {
39+
return err
40+
}
41+
42+
signal.SendConnectorAccountChanged(args.URL, args.ClientID, args.Account)
43+
return nil
44+
}

0 commit comments

Comments
 (0)