Implement block transfer logic

This commit is contained in:
2024-03-16 21:56:12 +01:00
parent 44598fc030
commit f41710c3dd
3 changed files with 177 additions and 41 deletions

View File

@@ -17,12 +17,12 @@ class Transaction:
def from_bytes(transaction_raw): def from_bytes(transaction_raw):
assert len(transaction_raw) == 148 assert len(transaction_raw) == 148
return Transaction( return Transaction(
id: int.from_bytes(transaction_raw[0:4], "big"), id = int.from_bytes(transaction_raw[0:4], "big"),
sender: transaction_raw[4:36], sender = transaction_raw[4:36],
receiver: transaction_raw[36:68], receiver = transaction_raw[36:68],
amount: int.from_bytes(transaction_raw[68:76], "big"), amount = int.from_bytes(transaction_raw[68:76], "big"),
transaction_fee: int.from_bytes(transaction_raw[76:84], "big"), transaction_fee = int.from_bytes(transaction_raw[76:84], "big"),
signature: transaction_raw[84:148], signature = transaction_raw[84:148],
) )
def is_valid(self): def is_valid(self):
sender_pubkey = Ed25519PublicKey.from_public_bytes(self.sender) sender_pubkey = Ed25519PublicKey.from_public_bytes(self.sender)
@@ -66,25 +66,25 @@ class Block:
used_transaction_ids: set used_transaction_ids: set
valid: bool valid: bool
def from_bytes(self, block_raw): def from_bytes(block_raw):
assert len(block_raw) == 292 assert len(block_raw) == 292
transaction_raw = block_raw[144:292] transaction_raw = block_raw[144:292]
if transaction_raw == 148 * b"\0": if transaction_raw == 148 * b"\0":
transaction = None transaction = None
else: else:
transaction = Transaction.from_bytes(transaction_raw) transaction = Transaction.from_bytes(transaction_raw)
block = Block( return Block(
nonce: int.from_bytes(block_raw[0:8], "big"), nonce = int.from_bytes(block_raw[0:8], "big"),
timestamp: int.from_bytes(block_raw[8:16], "big"), timestamp = int.from_bytes(block_raw[8:16], "big"),
previous_hash: block_raw[16:48], previous_hash = block_raw[16:48],
message: block_raw[48:80], message = block_raw[48:80],
difficulty_sum: int.from_bytes(block_raw[80:112], "big"), difficulty_sum = int.from_bytes(block_raw[80:112], "big"),
miner_pubkey: block_raw[112:144], miner_pubkey = block_raw[112:144],
transaction: transaction, transaction = transaction,
own_hash: hashlib.sha256(block_raw).digest(), own_hash = hashlib.sha256(block_raw).digest(),
balances: None, balances = None,
used_transaction_ids: None, used_transaction_ids = None,
valid: False, valid = False,
) )
def validate(self, blockchain): def validate(self, blockchain):
if self.transaction is not None: if self.transaction is not None:
@@ -92,15 +92,18 @@ class Block:
return False return False
if self.previous_hash != 32 * b"\0": if self.previous_hash != 32 * b"\0":
prev_block = blockchain.get_block(self.previous_hash) prev_block = blockchain.get_block(self.previous_hash)
if prev_block is None:
return False
if not prev_block.valid: if not prev_block.valid:
return False return False
if self.timestamp <= prev_block.timestamp: if self.timestamp <= prev_block.timestamp:
return False return False
if self.timestamp > time.time(): if self.timestamp > time.time():
return False return False
if not self.transaction.is_valid_after_block(prev_block): if self.transaction is not None and not self.transaction.is_valid_after_block(prev_block):
return False return False
else: else:
prev_block = None
if self.transaction is not None: if self.transaction is not None:
return False return False
# check for the correct miner pubkey - which will become public at launch day # check for the correct miner pubkey - which will become public at launch day
@@ -128,10 +131,10 @@ class Block:
return return
balances = prev_block.balances.copy() balances = prev_block.balances.copy()
balances.setdefault(self.miner_pubkey, 0) balances.setdefault(self.miner_pubkey, 0)
balances[miner_pubkey] += 100 balances[self.miner_pubkey] += 100
t = self.transaction t = self.transaction
if t is not None: if t is not None:
balances[miner_pubkey] += t.transaction_fee balances[self.miner_pubkey] += t.transaction_fee
balances[t.sender] -= (t.amount + t.transaction_fee) balances[t.sender] -= (t.amount + t.transaction_fee)
balances[t.receiver] += t.amount balances[t.receiver] += t.amount
self.balances = balances self.balances = balances
@@ -174,21 +177,41 @@ class Blockchain:
self.__latest_block_hash = None self.__latest_block_hash = None
self.__lock = Lock() self.__lock = Lock()
def set_latest_block(self, block_hash): def set_latest_block(self, block_hash):
with self.__lock:
new_block = self.get_block(block_hash) new_block = self.get_block(block_hash)
assert new_block is not None assert new_block is not None
assert new_block.valid assert new_block.valid
if self.__latest_block_hash is not None: while True:
current_difficulty_sum = self.__latest_block_hash.get_difficulty_info(1, self) with self.__lock:
new_difficulty_sum = new_block.get_difficulty_info(1, self) latest_block_hash = self.__latest_block_hash
if latest_block_hash is not None:
latest_block = self.get_block(latest_block_hash)
current_difficulty_sum = latest_block.get_difficulty_info(1, self)[0]
new_difficulty_sum = new_block.get_difficulty_info(1, self)[0]
if new_difficulty_sum <= current_difficulty_sum: if new_difficulty_sum <= current_difficulty_sum:
return return False
with self.__lock:
if self.__latest_block_hash != latest_block_hash:
continue
self.__latest_block_hash = block_hash self.__latest_block_hash = block_hash
return True
def add_block(self, block_raw): def add_block(self, block_raw):
with self.__lock: with self.__lock:
block = Block.from_bytes(block_raw) block = Block.from_bytes(block_raw)
if block.own_hash not in self.__block_map:
self.__block_map[block.own_hash] = block self.__block_map[block.own_hash] = block
return block return self.__block_map[block.own_hash]
def get_block(self, hash): def get_block(self, hash):
with self.__lock: with self.__lock:
return self.__block_map.get(hash) return self.__block_map.get(hash)
def get_second_last_difficulty_sum(self):
with self.__lock:
latest_block_hash = self.__latest_block_hash
if latest_block_hash is None:
return 0
block = self.get_block(latest_block_hash)
return block.get_difficulty_info(1, self)[0]
def get_latest_block(self):
with self.__lock:
if self.__latest_block_hash is None:
return None
return self.__block_map[self.__latest_block_hash]

98
node.py
View File

@@ -1,6 +1,7 @@
#! /usr/bin/env python3 #! /usr/bin/env python3
import hashlib, random, socket, sys, threading, time import blockchain, hashlib, observer, random, socket, sys, threading, time
from _queue import Empty
DEFAULT_PORT = 62039 DEFAULT_PORT = 62039
@@ -88,11 +89,11 @@ def empty_transaction_list_hash():
current_hash = hashlib.sha256(entry).digest() current_hash = hashlib.sha256(entry).digest()
return current_hash return current_hash
def send_heartbeat(node, peer): def send_heartbeat(node, peer, b):
protocol_version = 2 * b"\0" protocol_version = 2 * b"\0"
capable_version = 2 * b"\0" capable_version = 2 * b"\0"
msg_type = b"\0" msg_type = b"\0"
difficulty_sum = 32 * b"\0" difficulty_sum = b.get_second_last_difficulty_sum().to_bytes(32, "big")
hash_value = empty_transaction_list_hash() hash_value = empty_transaction_list_hash()
if peer.partner is None: if peer.partner is None:
@@ -117,7 +118,7 @@ def define_partnership(peers):
if pairing_count % 2 == 1: if pairing_count % 2 == 1:
peers_to_pair[-1].partner = None peers_to_pair[-1].partner = None
def heartbeat(node): def heartbeat(node, b):
while True: while True:
heartbeats, partners = node.get_events() heartbeats, partners = node.get_events()
heartbeats = set(heartbeats) heartbeats = set(heartbeats)
@@ -169,11 +170,61 @@ def heartbeat(node):
define_partnership(node.peers) define_partnership(node.peers)
for i, peer in enumerate(node.peers): for i, peer in enumerate(node.peers):
wait_until(start_time + 60 * (i+1) / peer_count) wait_until(start_time + 60 * (i+1) / peer_count)
send_heartbeat(node, peer) send_heartbeat(node, peer, b)
if len(node.peers) == 0: if len(node.peers) == 0:
time.sleep(60) time.sleep(60)
def receiver(node): class NoReponseException(Exception):
pass
def request_retry(node, addr, request, subscription, condition):
for _ in range(10):
node.node_socket.sendto(request, addr)
try:
while True:
response = subscription.receive(1)
if condition(response):
break
return response
except Empty:
pass
raise NoReponseException()
def transfer_block(addr, node, receive_observer, b):
try:
block_list = []
request_block_hash = 32 * b"\0"
subscription = receive_observer.listen((addr[0:2], "block transfer"))
while True:
if request_block_hash != 32 * b"\0":
existing_block = b.get_block(request_block_hash)
if existing_block is not None:
if existing_block.valid:
break
request_block_hash = existing_block.previous_hash
if request_block_hash == 32*b"\0":
break
continue
request = b"\0\0\0\0\x01" + request_block_hash
def condition(response_hash):
if request_block_hash == 32*b"\0":
return True
return response_hash == request_block_hash
block_hash = request_retry(node, addr, request, subscription, condition)
block_list.append(block_hash)
request_block_hash = b.get_block(block_hash).previous_hash
if request_block_hash == 32*b"\0":
break
for block_hash in reversed(block_list):
if not b.get_block(block_hash).validate(b):
return
if b.set_latest_block(block_hash):
log("Got a new block")
except NoReponseException:
pass
def receiver(node, b):
receive_observer = observer.Observer()
while True: while True:
msg, addr = node.node_socket.recvfrom(4096) msg, addr = node.node_socket.recvfrom(4096)
sender = describe(addr[0], addr[1]) sender = describe(addr[0], addr[1])
@@ -202,6 +253,36 @@ def receiver(node):
partner_ip = socket.inet_ntop(socket.AF_INET6, partner_info[0:16]) partner_ip = socket.inet_ntop(socket.AF_INET6, partner_info[0:16])
partner_port = int.from_bytes(partner_info[16:18], "big") partner_port = int.from_bytes(partner_info[16:18], "big")
node.add_partner((partner_ip, partner_port)) node.add_partner((partner_ip, partner_port))
contained_difficulty_sum = int.from_bytes(msg[5:37])
my_difficulty_sum = b.get_second_last_difficulty_sum()
if contained_difficulty_sum > my_difficulty_sum:
log("beginning a block transfer ...")
threading.Thread(target = transfer_block, args=(addr, node, receive_observer, b)).start()
elif msg_type == 1:
# block request
if msg_len != 37:
log(f"Got a block request of wrong length ({msg_len} bytes from {sender}, but expected 37 bytes)")
block_hash = msg[5:37]
if block_hash == 32 * b"\0":
block_to_send = b.get_latest_block()
else:
block_to_send = b.get_block(block_hash)
if block_to_send is None:
continue
block_raw = block_to_send.get_block_raw()
response_msg = b"\0\0\0\0\x02" + block_raw
node.node_socket.sendto(response_msg, addr)
elif msg_type == 2:
# block transfer
if msg_len != 297:
log(f"Got a block transfer of wrong length ({msg_len} bytes from {sender}, but expected 297 bytes)")
block_raw = msg[5:297]
new_block = b.add_block(block_raw)
block_hash = new_block.own_hash
if new_block.validate(b) and b.set_latest_block(block_hash):
log("Got a new block")
identifier = (addr[0:2], "block transfer")
receive_observer.publish(identifier, block_hash)
else: else:
log(f"Got a udp message of unknown type from {sender}. (type {msg_type})") log(f"Got a udp message of unknown type from {sender}. (type {msg_type})")
@@ -223,8 +304,9 @@ def main():
log("Node is ready") log("Node is ready")
node = Node(node_socket, peers) node = Node(node_socket, peers)
heartbeat_thread = threading.Thread(target = heartbeat, args = (node,)) b = blockchain.Blockchain()
receiving_thread = threading.Thread(target = receiver, args = (node,)) heartbeat_thread = threading.Thread(target = heartbeat, args = (node, b))
receiving_thread = threading.Thread(target = receiver, args = (node, b))
heartbeat_thread.start() heartbeat_thread.start()
receiving_thread.start() receiving_thread.start()
heartbeat_thread.join() heartbeat_thread.join()

31
observer.py Normal file
View File

@@ -0,0 +1,31 @@
from multiprocessing import Lock, Queue
class Observer:
def __init__(self):
self.__receivers_list = {}
self.__lock = Lock()
def listen(self, identifier):
with self.__lock:
queue = Queue()
self.__receivers_list.setdefault(identifier, set())
self.__receivers_list[identifier].add(queue)
return Subscription(self, identifier, queue)
def publish(self, identifier, message):
if identifier in self.__receivers_list:
for queue in self.__receivers_list[identifier]:
queue.put(message)
def quit(self, identifer, queue):
with self.__lock:
self.__receivers_list[identifer].remove(queue)
if len(self.__receivers_list[identifer]) == 0:
del self.__receivers_list[identifer]
class Subscription:
def __init__(self, observer, identifier, queue):
self.__observer = observer
self.__identifier = identifier
self.__queue = queue
def receive(self, timeout):
return self.__queue.get(timeout=timeout)
def __del__(self):
self.__observer.quit(self.__identifier, self.__queue)