Skip to content

Commit

Permalink
feat: Initial TssVoteHandler (#16061)
Browse files Browse the repository at this point in the history
Signed-off-by: Derek Riley <[email protected]>
Signed-off-by: Neeharika-Sompalli <[email protected]>
Co-authored-by: Neeharika-Sompalli <[email protected]>
  • Loading branch information
derektriley and Neeharika-Sompalli authored Oct 25, 2024
1 parent 0500bcd commit 76c54fc
Show file tree
Hide file tree
Showing 5 changed files with 248 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
import com.hedera.node.app.spi.workflows.record.StreamBuilder;
import com.swirlds.config.api.Configuration;
import com.swirlds.state.spi.info.NetworkInfo;
import com.swirlds.state.spi.info.NodeInfo;
import edu.umd.cs.findbugs.annotations.NonNull;
import edu.umd.cs.findbugs.annotations.Nullable;
import java.time.Instant;
Expand Down Expand Up @@ -551,6 +552,12 @@ static void throwIfMissingPayerId(@NonNull final TransactionBody body) {
@NonNull
Map<AccountID, Long> dispatchPaidRewards();

/**
* Returns the {@link NodeInfo} for the node this transaction is created from.
* @return the node info
*/
NodeInfo creatorInfo();

/**
* Whether a dispatch should be throttled at consensus. True for everything except certain dispatches
* internal to the EVM which are only constrained by gas.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,19 @@

import static java.util.Objects.requireNonNull;

import com.hedera.hapi.node.state.roster.Roster;
import com.hedera.hapi.node.state.roster.RosterEntry;
import com.hedera.hapi.node.state.tss.TssVoteMapKey;
import com.hedera.hapi.node.transaction.TransactionBody;
import com.hedera.hapi.services.auxiliary.tss.TssVoteTransactionBody;
import com.hedera.node.app.spi.workflows.HandleContext;
import com.hedera.node.app.spi.workflows.HandleException;
import com.hedera.node.app.spi.workflows.PreCheckException;
import com.hedera.node.app.spi.workflows.PreHandleContext;
import com.hedera.node.app.spi.workflows.TransactionHandler;
import com.hedera.node.app.tss.stores.WritableTssStore;
import com.hedera.pbj.runtime.io.buffer.Bytes;
import com.swirlds.platform.state.service.ReadableRosterStore;
import edu.umd.cs.findbugs.annotations.NonNull;
import javax.inject.Inject;
import javax.inject.Singleton;
Expand All @@ -35,6 +41,7 @@
*/
@Singleton
public class TssVoteHandler implements TransactionHandler {

@Inject
public TssVoteHandler() {
// Dagger2
Expand All @@ -53,5 +60,61 @@ public void pureChecks(@NonNull final TransactionBody txn) throws PreCheckExcept
@Override
public void handle(@NonNull final HandleContext context) throws HandleException {
requireNonNull(context);
final var txBody = context.body().tssVoteOrThrow();
final var tssBaseStore = context.storeFactory().writableStore(WritableTssStore.class);
final TssVoteMapKey tssVoteMapKey = new TssVoteMapKey(
txBody.targetRosterHash(), context.creatorInfo().nodeId());
if (tssBaseStore.exists(tssVoteMapKey)) {
// Duplicate vote
return;
}

if (!TssVoteHandler.hasReachedThreshold(txBody, context)) {
tssBaseStore.put(tssVoteMapKey, txBody);
}
}

/**
* Check if the threshold number of votes (totaling at least 1/3 of weight) have already been received for the
* candidate roster, all with the same vote byte array.
*
* @param tssVoteTransaction the TssVoteTransaction to check
* @param context the HandleContext
* @return true if the threshold has been reached, false otherwise
*/
public static boolean hasReachedThreshold(
@NonNull final TssVoteTransactionBody tssVoteTransaction, @NonNull final HandleContext context) {
final var rosterStore = context.storeFactory().readableStore(ReadableRosterStore.class);

final Roster activeRoster = rosterStore.getActiveRoster();
if (activeRoster == null) {
throw new IllegalArgumentException("No active roster found");
}
// Get the target roster from the TssVoteTransactionBody
final Bytes targetRosterHash = tssVoteTransaction.targetRosterHash();

// Also get the total active roster weight
long activeRosterTotalWeight = 0;
// Initialize a counter for the total weight of votes with the same vote byte array
long voteWeight = 0L;
final var tssBaseStore = context.storeFactory().writableStore(WritableTssStore.class);
// For every node in the active roster, check if there is a vote for the target roster hash
for (final RosterEntry rosterEntry : activeRoster.rosterEntries()) {
activeRosterTotalWeight += rosterEntry.weight();
final var tssVoteMapKey = new TssVoteMapKey(targetRosterHash, rosterEntry.nodeId());
if (tssBaseStore.exists(tssVoteMapKey)) {
final var vote = tssBaseStore.getVote(tssVoteMapKey);
// If the vote byte array matches the one in the TssVoteTransaction, add the weight of the vote to the
// counter
if (vote.tssVote().equals(tssVoteTransaction.tssVote())) {
voteWeight += rosterEntry.weight();
}
}
}

// Check if the total weight of votes with the same vote byte array is at least 1/3 of the
// total weight of the network
// Adding a +1 to the threshold to account for rounding errors.
return voteWeight >= (activeRosterTotalWeight / 3) + ((activeRosterTotalWeight % 3) == 0 ? 0 : 1);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -180,8 +180,9 @@ public boolean tryToChargePayer(final long amount) {
return feeAccumulator.chargeNetworkFee(payerId, amount);
}

@NonNull
@Override
public @NonNull Configuration configuration() {
public Configuration configuration() {
return config;
}

Expand Down Expand Up @@ -462,6 +463,11 @@ public Map<AccountID, Long> dispatchPaidRewards() {
return dispatchPaidRewards == null ? emptyMap() : dispatchPaidRewards;
}

@Override
public NodeInfo creatorInfo() {
return creatorInfo;
}

private <T> T dispatchForRecord(
@NonNull final TransactionBody childTxBody,
@NonNull final Class<T> recordBuilderClass,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,42 +17,201 @@
package com.hedera.node.app.tss.handlers;

import static org.junit.jupiter.api.Assertions.*;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mock.Strictness.LENIENT;
import static org.mockito.Mockito.mockStatic;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

import com.hedera.hapi.node.state.roster.Roster;
import com.hedera.hapi.node.state.roster.RosterEntry;
import com.hedera.hapi.node.state.tss.TssVoteMapKey;
import com.hedera.hapi.node.transaction.TransactionBody;
import com.hedera.hapi.services.auxiliary.tss.TssVoteTransactionBody;
import com.hedera.node.app.spi.store.StoreFactory;
import com.hedera.node.app.spi.workflows.HandleContext;
import com.hedera.node.app.spi.workflows.HandleException;
import com.hedera.node.app.spi.workflows.PreHandleContext;
import com.hedera.node.app.tss.stores.WritableTssStore;
import com.hedera.pbj.runtime.io.buffer.Bytes;
import com.swirlds.platform.state.service.ReadableRosterStore;
import com.swirlds.state.spi.info.NodeInfo;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
import org.mockito.MockedStatic;
import org.mockito.MockitoAnnotations;
import org.mockito.junit.jupiter.MockitoExtension;

@ExtendWith(MockitoExtension.class)
class TssVoteHandlerTest {
@Mock
private TssSubmissions submissionManager;

@Mock
private PreHandleContext preHandleContext;

@Mock
@Mock(strictness = LENIENT)
private HandleContext handleContext;

private TssVoteHandler subject;
@Mock
private WritableTssStore tssBaseStore;

@Mock
private ReadableRosterStore rosterStore;

@Mock
private TssVoteTransactionBody tssVoteTransactionBody;

@Mock
private TransactionBody transactionBody;

@Mock
private StoreFactory storeFactory;

@Mock(strictness = LENIENT)
private NodeInfo nodeInfo;

private TssVoteHandler tssVoteHandler;

@BeforeEach
void setUp() {
subject = new TssVoteHandler();
MockitoAnnotations.openMocks(this);
tssVoteHandler = new TssVoteHandler();
when(handleContext.creatorInfo()).thenReturn(nodeInfo);
when(nodeInfo.nodeId()).thenReturn(1L);
}

@Test
void handleDoesNotThrowWhenValidContext() throws HandleException {
when(handleContext.body()).thenReturn(transactionBody);
when(transactionBody.tssVoteOrThrow()).thenReturn(tssVoteTransactionBody);
when(handleContext.storeFactory()).thenReturn(storeFactory);
when(storeFactory.writableStore(WritableTssStore.class)).thenReturn(tssBaseStore);
;

when(tssVoteTransactionBody.targetRosterHash()).thenReturn(Bytes.EMPTY);
when(tssBaseStore.exists(any(TssVoteMapKey.class))).thenReturn(false);

try (MockedStatic<TssVoteHandler> mockedStatic = mockStatic(TssVoteHandler.class)) {
mockedStatic
.when(() -> TssVoteHandler.hasReachedThreshold(any(), any()))
.thenReturn(false);
tssVoteHandler.handle(handleContext);
}

verify(tssBaseStore).put(any(TssVoteMapKey.class), eq(tssVoteTransactionBody));
}

@Test
void handleReturnsWhenDuplicateVoteExists() throws HandleException {
when(handleContext.body()).thenReturn(transactionBody);
when(transactionBody.tssVoteOrThrow()).thenReturn(tssVoteTransactionBody);
when(handleContext.storeFactory()).thenReturn(storeFactory);
when(storeFactory.writableStore(WritableTssStore.class)).thenReturn(tssBaseStore);
when(tssVoteTransactionBody.targetRosterHash()).thenReturn(Bytes.EMPTY);
when(tssBaseStore.exists(any(TssVoteMapKey.class))).thenReturn(true);

tssVoteHandler.handle(handleContext);

verify(tssBaseStore, never()).put(any(TssVoteMapKey.class), eq(tssVoteTransactionBody));
}

@Test
void hasReachedThresholdReturnsFalseWhenThresholdIsNotMet() {
// Setup in-memory data
final RosterEntry rosterEntry1 = new RosterEntry(1L, 1L, null, null, List.of());
final RosterEntry rosterEntry2 = new RosterEntry(2L, 4L, null, null, List.of());
final RosterEntry rosterEntry3 = new RosterEntry(3L, 2L, null, null, List.of());
final Roster roster = new Roster(List.of(rosterEntry1, rosterEntry2, rosterEntry3));
final TssVoteTransactionBody voteTransactionBody =
new TssVoteTransactionBody(Bytes.EMPTY, Bytes.EMPTY, Bytes.EMPTY, Bytes.EMPTY, Bytes.EMPTY);
final TssVoteTransactionBody voteTransactionBody2 =
new TssVoteTransactionBody(Bytes.EMPTY, Bytes.EMPTY, Bytes.EMPTY, Bytes.EMPTY, Bytes.fromHex("01"));
final TssVoteTransactionBody voteTransactionBody3 =
new TssVoteTransactionBody(Bytes.EMPTY, Bytes.EMPTY, Bytes.EMPTY, Bytes.EMPTY, Bytes.fromHex("02"));

// Setup stores
final Map<TssVoteMapKey, TssVoteTransactionBody> voteStore = new HashMap<>();
voteStore.put(new TssVoteMapKey(Bytes.EMPTY, 1L), voteTransactionBody);
voteStore.put(new TssVoteMapKey(Bytes.EMPTY, 2L), voteTransactionBody2);
voteStore.put(new TssVoteMapKey(Bytes.EMPTY, 3L), voteTransactionBody3);

// Mock behavior
when(handleContext.storeFactory()).thenReturn(storeFactory);
when(storeFactory.writableStore(WritableTssStore.class)).thenReturn(tssBaseStore);
when(storeFactory.readableStore(ReadableRosterStore.class)).thenReturn(rosterStore);
when(rosterStore.getActiveRoster()).thenReturn(roster);
when(tssBaseStore.exists(any(TssVoteMapKey.class)))
.thenAnswer(invocation -> voteStore.containsKey(invocation.getArgument(0)));
when(tssBaseStore.getVote(any(TssVoteMapKey.class)))
.thenAnswer(invocation -> voteStore.get(invocation.getArgument(0)));

final boolean result = TssVoteHandler.hasReachedThreshold(voteTransactionBody, handleContext);

assertFalse(result, "Threshold should not be met");
}

@Test
void hasReachedThresholdReturnsTrueWhenThresholdIsMet() {
// Setup in-memory data
final RosterEntry rosterEntry1 = new RosterEntry(1L, 1L, null, null, List.of());
final RosterEntry rosterEntry2 = new RosterEntry(2L, 2L, null, null, List.of());
final RosterEntry rosterEntry3 = new RosterEntry(3L, 3L, null, null, List.of());
final Roster roster = new Roster(List.of(rosterEntry1, rosterEntry2, rosterEntry3));
final TssVoteTransactionBody voteTransactionBody =
new TssVoteTransactionBody(Bytes.EMPTY, Bytes.EMPTY, Bytes.EMPTY, Bytes.EMPTY, Bytes.EMPTY);
final TssVoteTransactionBody voteTransactionBody2 =
new TssVoteTransactionBody(Bytes.EMPTY, Bytes.EMPTY, Bytes.EMPTY, Bytes.EMPTY, Bytes.fromHex("01"));
final TssVoteTransactionBody voteTransactionBody3 =
new TssVoteTransactionBody(Bytes.EMPTY, Bytes.EMPTY, Bytes.EMPTY, Bytes.EMPTY, Bytes.EMPTY);

// Setup stores
final Map<TssVoteMapKey, TssVoteTransactionBody> voteStore = new HashMap<>();
voteStore.put(new TssVoteMapKey(Bytes.EMPTY, 1L), voteTransactionBody);
voteStore.put(new TssVoteMapKey(Bytes.EMPTY, 2L), voteTransactionBody2);
voteStore.put(new TssVoteMapKey(Bytes.EMPTY, 3L), voteTransactionBody3);

// Mock behavior
when(handleContext.storeFactory()).thenReturn(storeFactory);
when(storeFactory.writableStore(WritableTssStore.class)).thenReturn(tssBaseStore);
when(storeFactory.readableStore(ReadableRosterStore.class)).thenReturn(rosterStore);
when(rosterStore.getActiveRoster()).thenReturn(roster);
when(tssBaseStore.exists(any(TssVoteMapKey.class)))
.thenAnswer(invocation -> voteStore.containsKey(invocation.getArgument(0)));
when(tssBaseStore.getVote(any(TssVoteMapKey.class)))
.thenAnswer(invocation -> voteStore.get(invocation.getArgument(0)));

boolean result = TssVoteHandler.hasReachedThreshold(voteTransactionBody, handleContext);

assertTrue(result, "Threshold should be met");
}

@Test
void nothingImplementedYet() {
assertDoesNotThrow(() -> subject.preHandle(preHandleContext));
assertDoesNotThrow(() -> subject.pureChecks(tssVote()));
assertDoesNotThrow(() -> subject.handle(handleContext));
void preHandleDoesNotThrowWhenContextIsValid() {
assertDoesNotThrow(() -> tssVoteHandler.preHandle(preHandleContext));
}

private TransactionBody tssVote() {
return TransactionBody.DEFAULT;
@Test
void pureChecksDoesNotThrowWhenTransactionBodyIsValid() {
assertDoesNotThrow(() -> tssVoteHandler.pureChecks(transactionBody));
}

@Test
void hasReachedThresholdThrowsIllegalArgumentExceptionWhenActiveRosterIsNull() {
when(handleContext.storeFactory()).thenReturn(storeFactory);
when(storeFactory.readableStore(ReadableRosterStore.class)).thenReturn(rosterStore);
when(rosterStore.getActiveRoster()).thenReturn(null);

TssVoteTransactionBody voteTransactionBody =
new TssVoteTransactionBody(Bytes.EMPTY, Bytes.EMPTY, Bytes.EMPTY, Bytes.EMPTY, Bytes.EMPTY);

assertThrows(
IllegalArgumentException.class,
() -> TssVoteHandler.hasReachedThreshold(voteTransactionBody, handleContext));
}
}
2 changes: 1 addition & 1 deletion platform-sdk/docs/proposals/TSS-Ledger-Id/TSS-Ledger-Id.md
Original file line number Diff line number Diff line change
Expand Up @@ -945,7 +945,7 @@ Outputs:

1. If voting is closed for the target roster or the vote is a second vote from the originating node, do nothing.
2. Add the `TssVoteTransaction` to the list for the target roster.
3. If the voting threshold is met by at least 1/2 consensus weight voting yes:
3. If the voting threshold is met by at least 1/3 consensus weight voting yes:
1. add the target roster hash to the` `votingClosed` set.
2. Non-Dynamic Address Book Semantics
1. if `keyActiveRoster` is false, do nothing here, rely on the startup logic to rotate the candidate roster to
Expand Down

0 comments on commit 76c54fc

Please sign in to comment.