Skip to content

Commit

Permalink
fix: Update Rewards method for Lido
Browse files Browse the repository at this point in the history
* refactor: Get rewards using linear search

* refactor: Get shares using linear search

* fix: Update deprecated package for RPCs

* fix: GetPublicRPCs randomize RPCs
  • Loading branch information
khalifaa55 authored Sep 12, 2024
1 parent 149bcdc commit 9563877
Show file tree
Hide file tree
Showing 5 changed files with 66 additions and 115 deletions.
8 changes: 4 additions & 4 deletions configs/networks.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"errors"
"fmt"
"math/rand"
"sort"
"time"
)

Expand Down Expand Up @@ -80,10 +81,9 @@ func GetPublicRPCs(network string) ([]string, error) {
shuffledRPCs := make([]string, len(rpcs.PublicRPCs))
copy(shuffledRPCs, rpcs.PublicRPCs)

// Shuffle the slice to randomize the order
rand.Seed(time.Now().UnixNano())
rand.Shuffle(len(shuffledRPCs), func(i, j int) {
shuffledRPCs[i], shuffledRPCs[j] = shuffledRPCs[j], shuffledRPCs[i]
// Randomize the slice order
sort.Slice(shuffledRPCs, func(i, j int) bool {
return rand.Float32() < 0.5
})

return shuffledRPCs, nil
Expand Down
28 changes: 28 additions & 0 deletions configs/networks_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,3 +108,31 @@ func TestSupportMEVBoost(t *testing.T) {
})
}
}

func TestGetPublicRPCs_Randomness(t *testing.T) {
networkRPCs = map[string]RPC{
"testnet": {PublicRPCs: []string{"rpc1", "rpc2", "rpc3", "rpc4", "rpc5"}},
}

// Run multiple times to check randomness
iterations := 100
results := make([]string, iterations)

for i := 0; i < iterations; i++ {
got, err := GetPublicRPCs("testnet")
if err != nil {
t.Fatalf("GetPublicRPCs() error = %v", err)
}
results[i] = got[0] // Store the first RPC of each result
}

// Check if we have different first elements (indicating randomness)
uniqueFirstElements := make(map[string]bool)
for _, r := range results {
uniqueFirstElements[r] = true
}

if len(uniqueFirstElements) < 2 {
t.Errorf("GetPublicRPCs() doesn't seem to randomize the order. Got %d unique first elements out of %d iterations", len(uniqueFirstElements), iterations)
}
}
16 changes: 16 additions & 0 deletions configs/public_rpcs.go
Original file line number Diff line number Diff line change
@@ -1,3 +1,18 @@
/*
Copyright 2022 Nethermind
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package configs

type RPC struct {
Expand All @@ -14,6 +29,7 @@ var networkRPCs = map[string]RPC{
"https://rpc.mevblocker.io",
"https://ethereum-rpc.publicnode.com",
"https://rpc.flashbots.net",
"https://eth.drpc.org",
},
},
NetworkHolesky: {
Expand Down
48 changes: 18 additions & 30 deletions internal/lido/contracts/csfeedistributor/rewards.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"encoding/json"
"fmt"
"io"
"log"
"math/big"
"time"

Expand Down Expand Up @@ -85,16 +86,24 @@ func cumulativeFeeShares(treeCID string, nodeID *big.Int) (*big.Int, error) {
return nil, fmt.Errorf("error getting tree data: %v", err)
}

index, err := binarySearchNodeID(nodeID, treeData)
if err != nil {
return nil, fmt.Errorf("failed to find node ID: %v", err)
}

shares, err := convertTreeValuesToBigInt(treeData.Values[index].Value[1])
if err != nil {
return nil, fmt.Errorf("failed to convert shares: %v", err)
// Compare nodeOperatorID in tree with nodeId to get shares
for _, item := range treeData.Values {
if len(item.Value) == 2 {
nodeOperatorId, err1 := convertTreeValuesToBigInt(item.Value[0])
shares, err2 := convertTreeValuesToBigInt(item.Value[1])
if err1 != nil || err2 != nil {
log.Println("Error converting values:", err1, err2)
continue
}
if nodeOperatorId.Cmp(nodeID) == 0 {
log.Printf("shares: %v", shares)
return shares, nil
}
} else {
log.Println("Unexpected value format, expected 2 elements")
}
}
return shares, nil
return nil, fmt.Errorf("invalid nodeId")
}

func treeCID(network string) (string, error) {
Expand Down Expand Up @@ -141,27 +150,6 @@ func convertTreeValuesToBigInt(value interface{}) (*big.Int, error) {
return bigIntValue, nil
}

func binarySearchNodeID(nodeID *big.Int, treeData Tree) (int, error) {
// Compare nodeOperatorID in tree with nodeId to get shares
low, high := 0, len(treeData.Values)-1
for low <= high {
mid := (low + high) / 2
nodeOperatorId, err := convertTreeValuesToBigInt(treeData.Values[mid].Value[0])
if err != nil {
return 0, fmt.Errorf("failed to convert nodeOperatorId: %v", err)
}
cmp := nodeOperatorId.Cmp(nodeID)
if cmp == 0 {
return mid, nil
} else if cmp < 0 {
low = mid + 1
} else {
high = mid - 1
}
}
return 0, fmt.Errorf("invalid node ID")
}

func csFeeDistributorContract(network string) (*Csfeedistributor, *ethclient.Client, error) {
client, err := contracts.ConnectClient(network)
if err != nil {
Expand Down
81 changes: 0 additions & 81 deletions internal/lido/contracts/csfeedistributor/rewards_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,84 +113,3 @@ func TestConvertTreeValuesToBigInt(t *testing.T) {
}
}
}

func TestBinarySearchNodeID(t *testing.T) {
tcs := []struct {
name string
treeData Tree
nodeID *big.Int
expected int
expectErr bool
}{
{
name: "Node ID found at index 1",
treeData: Tree{
Values: []struct {
Value []interface{} `json:"value"`
TreeIndex int `json:"treeIndex"`
}{
{Value: []interface{}{1.0, 5000.0}, TreeIndex: 0},
{Value: []interface{}{2.0, 6000.0}, TreeIndex: 1},
{Value: []interface{}{30.0, 7000.0}, TreeIndex: 2},
},
},
nodeID: big.NewInt(2),
expected: 1,
},
{
name: "Node ID found at index 0",
treeData: Tree{
Values: []struct {
Value []interface{} `json:"value"`
TreeIndex int `json:"treeIndex"`
}{
{Value: []interface{}{100.0, 5000.0}, TreeIndex: 0},
{Value: []interface{}{200.0, 6000.0}, TreeIndex: 1},
{Value: []interface{}{330.0, 7000.0}, TreeIndex: 2},
},
},
nodeID: big.NewInt(100),
expected: 0,
},
{
name: "Node ID not found",
treeData: Tree{
Values: []struct {
Value []interface{} `json:"value"`
TreeIndex int `json:"treeIndex"`
}{
{Value: []interface{}{10.0, 5000}, TreeIndex: 0},
{Value: []interface{}{20.0, 6000}, TreeIndex: 1},
{Value: []interface{}{30.0, 7000}, TreeIndex: 2},
},
},
nodeID: big.NewInt(400),
expected: 0,
expectErr: true,
},
{
name: "empty tree",
treeData: Tree{
Values: []struct {
Value []interface{} `json:"value"`
TreeIndex int `json:"treeIndex"`
}{},
},
nodeID: big.NewInt(4),
expected: 0,
expectErr: true,
},
}

for _, tc := range tcs {
t.Run(tc.name, func(t *testing.T) {
index, err := binarySearchNodeID(tc.nodeID, tc.treeData)
if (err != nil) != tc.expectErr {
t.Errorf("expected error: %v, got: %v", tc.expectErr, err)
}
if err == nil && index != tc.expected {
t.Errorf("expected index: %v, got: %v", tc.expected, index)
}
})
}
}

0 comments on commit 9563877

Please sign in to comment.