diff --git a/src/network/protocol/messages/cmpctblock.zig b/src/network/protocol/messages/cmpctblock.zig new file mode 100644 index 0000000..76439be --- /dev/null +++ b/src/network/protocol/messages/cmpctblock.zig @@ -0,0 +1,225 @@ +const std = @import("std"); +const protocol = @import("../lib.zig"); +const Transaction = @import("../../../types/transaction.zig"); + +const Sha256 = std.crypto.hash.sha2.Sha256; +const BlockHeader = @import("../../../types/block_header.zig"); +const CompactSizeUint = @import("bitcoin-primitives").types.CompatSizeUint; +const genericChecksum = @import("lib.zig").genericChecksum; + +pub const CmpctBlockMessage = struct { + header: BlockHeader, + nonce: u64, + short_ids: []u64, + prefilled_txns: []PrefilledTransaction, + + const Self = @This(); + + pub const PrefilledTransaction = struct { + index: usize, + tx: Transaction, + }; + + pub fn name() *const [12]u8 { + return protocol.CommandNames.CMPCTBLOCK; + } + + pub fn checksum(self: *const Self) [4]u8 { + return genericChecksum(self); + } + + pub fn deinit(self: *Self, allocator: std.mem.Allocator) void { + allocator.free(self.short_ids); + for (self.prefilled_txns) |*txn| { + txn.tx.deinit(); + } + allocator.free(self.prefilled_txns); + } + + pub fn serializeToWriter(self: *const Self, w: anytype) !void { + comptime { + if (!@hasDecl(@TypeOf(w), "writeInt")) { + @compileError("Writer must have a writeInt method"); + } + } + + try self.header.serializeToWriter(w); + try w.writeInt(u64, self.nonce, .little); + + const short_ids_count = CompactSizeUint.new(self.short_ids.len); + try short_ids_count.encodeToWriter(w); + for (self.short_ids) |id| { + try w.writeInt(u64, id, .little); + } + + const prefilled_txns_count = CompactSizeUint.new(self.prefilled_txns.len); + try prefilled_txns_count.encodeToWriter(w); + + for (self.prefilled_txns) |txn| { + try CompactSizeUint.new(txn.index).encodeToWriter(w); + try txn.tx.serializeToWriter(w); + } + } + + pub fn serializeToSlice(self: *const Self, buffer: []u8) !void { + var fbs = std.io.fixedBufferStream(buffer); + try self.serializeToWriter(fbs.writer()); + } + + pub fn serialize(self: *const Self, allocator: std.mem.Allocator) ![]u8 { + const serialized_len = self.hintSerializedLen(); + if (serialized_len == 0) return &.{}; + const ret = try allocator.alloc(u8, serialized_len); + errdefer allocator.free(ret); + + try self.serializeToSlice(ret); + + return ret; + } + + pub fn deserializeReader(allocator: std.mem.Allocator, r: anytype) !Self { + comptime { + if (!@hasDecl(@TypeOf(r), "readInt")) { + @compileError("Reader must have a readInt method"); + } + } + + const header = try BlockHeader.deserializeReader(r); + const nonce = try r.readInt(u64, .little); + + const short_ids_count = try CompactSizeUint.decodeReader(r); + const short_ids = try allocator.alloc(u64, short_ids_count.value()); + errdefer allocator.free(short_ids); + + for (short_ids) |*id| { + id.* = try r.readInt(u64, .little); + } + + const prefilled_txns_count = try CompactSizeUint.decodeReader(r); + const prefilled_txns = try allocator.alloc(PrefilledTransaction, prefilled_txns_count.value()); + errdefer allocator.free(prefilled_txns); + + for (prefilled_txns) |*txn| { + const index = try CompactSizeUint.decodeReader(r); + const tx = try Transaction.deserializeReader(allocator, r); + + txn.* = PrefilledTransaction{ + .index = index.value(), + .tx = tx, + }; + } + + return Self{ + .header = header, + .nonce = nonce, + .short_ids = short_ids, + .prefilled_txns = prefilled_txns, + }; + } + + pub fn deserializeSlice(allocator: std.mem.Allocator, bytes: []const u8) !Self { + var fbs = std.io.fixedBufferStream(bytes); + return try Self.deserializeReader(allocator, fbs.reader()); + } + + pub fn hintSerializedLen(self: *const Self) usize { + var len: usize = 80 + 8; // BlockHeader + nonce + len += CompactSizeUint.new(self.short_ids.len).hint_encoded_len(); + len += self.short_ids.len * 8; + len += CompactSizeUint.new(self.prefilled_txns.len).hint_encoded_len(); + for (self.prefilled_txns) |txn| { + len += CompactSizeUint.new(txn.index).hint_encoded_len(); + len += txn.tx.hintEncodedLen(); + } + return len; + } + + pub fn eql(self: *const Self, other: *const Self) bool { + if (self.header.version != other.header.version or + !std.mem.eql(u8, &self.header.prev_block, &other.header.prev_block) or + !std.mem.eql(u8, &self.header.merkle_root, &other.header.merkle_root) or + self.header.timestamp != other.header.timestamp or + self.header.nbits != other.header.nbits or + self.header.nonce != other.header.nonce or + self.nonce != other.nonce) return false; + + if (self.short_ids.len != other.short_ids.len) return false; + for (self.short_ids, other.short_ids) |a, b| { + if (a != b) return false; + } + if (self.prefilled_txns.len != other.prefilled_txns.len) return false; + for (self.prefilled_txns, other.prefilled_txns) |a, b| { + if (a.index != b.index or !a.tx.eql(b.tx)) return false; + } + return true; + } +}; + +test "CmpctBlockMessage serialization and deserialization" { + const testing = std.testing; + const Hash = @import("../../../types/hash.zig"); + const Script = @import("../../../types/script.zig"); + const OutPoint = @import("../../../types/outpoint.zig"); + const OpCode = @import("../../../script/opcodes/constant.zig").Opcode; + + const test_allocator = testing.allocator; + + // Create a sample BlockHeader + const header = BlockHeader{ + .version = 1, + .prev_block = [_]u8{0} ** 32, // Zero-filled array of 32 bytes + .merkle_root = [_]u8{0} ** 32, // Zero-filled array of 32 bytes + .timestamp = 1631234567, + .nbits = 0x1d00ffff, + .nonce = 12345, + }; + + // Create sample short_ids + const short_ids = try test_allocator.alloc(u64, 2); + defer test_allocator.free(short_ids); + short_ids[0] = 123456789; + short_ids[1] = 987654321; + + // Create a sample Transaction + var tx = try Transaction.init(test_allocator); + defer tx.deinit(); + try tx.addInput(OutPoint{ .hash = Hash.newZeroed(), .index = 0 }); + { + var script_pubkey = try Script.init(test_allocator); + defer script_pubkey.deinit(); + try script_pubkey.push(&[_]u8{ OpCode.OP_DUP.toBytes(), OpCode.OP_HASH160.toBytes(), OpCode.OP_EQUALVERIFY.toBytes(), OpCode.OP_CHECKSIG.toBytes() }); + try tx.addOutput(50000, script_pubkey); + } + + // Create sample prefilled_txns + const prefilled_txns = try test_allocator.alloc(CmpctBlockMessage.PrefilledTransaction, 1); + defer test_allocator.free(prefilled_txns); + prefilled_txns[0] = .{ + .index = 0, + .tx = tx, + }; + + // Create CmpctBlockMessage + const msg = CmpctBlockMessage{ + .header = header, + .nonce = 9876543210, + .short_ids = short_ids, + .prefilled_txns = prefilled_txns, + }; + + // Test serialization + const serialized = try msg.serialize(test_allocator); + defer test_allocator.free(serialized); + + // Test deserialization + var deserialized = try CmpctBlockMessage.deserializeSlice(test_allocator, serialized); + defer deserialized.deinit(test_allocator); + + // Verify deserialized data + try std.testing.expect(msg.eql(&deserialized)); + + // Test hintSerializedLen + const hint_len = msg.hintSerializedLen(); + try testing.expect(hint_len > 0); + try testing.expect(hint_len == serialized.len); +} diff --git a/src/network/protocol/messages/lib.zig b/src/network/protocol/messages/lib.zig index a35b23c..c98dc18 100644 --- a/src/network/protocol/messages/lib.zig +++ b/src/network/protocol/messages/lib.zig @@ -48,6 +48,7 @@ pub const InventoryVector = struct { }; } }; +pub const CmpctBlockMessage = @import("cmpctblock.zig").CmpctBlockMessage; pub const MessageTypes = enum { version, @@ -67,6 +68,7 @@ pub const MessageTypes = enum { sendheaders, filterload, headers, + cmpctblock, }; pub const Message = union(MessageTypes) { @@ -87,6 +89,7 @@ pub const Message = union(MessageTypes) { sendheaders: SendHeadersMessage, filterload: FilterLoadMessage, headers: HeadersMessage, + cmpctblock: CmpctBlockMessage, pub fn name(self: Message) *const [12]u8 { return switch (self) { @@ -107,6 +110,7 @@ pub const Message = union(MessageTypes) { .sendheaders => |m| @TypeOf(m).name(), .filterload => |m| @TypeOf(m).name(), .headers => |m| @TypeOf(m).name(), + .cmpctblock => |m| @TypeOf(m).name(), }; } @@ -126,6 +130,7 @@ pub const Message = union(MessageTypes) { .block => |*m| m.deinit(allocator), .filteradd => |*m| m.deinit(allocator), .notfound => {}, + .cmpctblock => |*m| m.deinit(allocator), .sendheaders => {}, .filterload => {}, .headers => |*m| m.deinit(allocator), @@ -151,6 +156,7 @@ pub const Message = union(MessageTypes) { .sendheaders => |*m| m.checksum(), .filterload => |*m| m.checksum(), .headers => |*m| m.checksum(), + .cmpctblock => |*m| m.checksum(), }; } @@ -173,6 +179,7 @@ pub const Message = union(MessageTypes) { .sendheaders => |m| m.hintSerializedLen(), .filterload => |*m| m.hintSerializedLen(), .headers => |*m| m.hintSerializedLen(), + .cmpctblock => |*m| m.hintSerializedLen(), }; } }; diff --git a/src/network/wire/lib.zig b/src/network/wire/lib.zig index bf0d022..3267d19 100644 --- a/src/network/wire/lib.zig +++ b/src/network/wire/lib.zig @@ -141,6 +141,8 @@ pub fn receiveMessage( protocol.messages.Message{ .sendheaders = try protocol.messages.SendHeadersMessage.deserializeReader(allocator, r) } else if (std.mem.eql(u8, &command, protocol.messages.FilterLoadMessage.name())) protocol.messages.Message{ .filterload = try protocol.messages.FilterLoadMessage.deserializeReader(allocator, r) } + else if (std.mem.eql(u8, &command, protocol.messages.CmpctBlockMessage.name())) + protocol.messages.Message{ .cmpctblock = try protocol.messages.CmpctBlockMessage.deserializeReader(allocator, r) } else { try r.skipBytes(payload_len, .{}); // Purge the wire return error.UnknownMessage; @@ -579,3 +581,73 @@ test "ok_send_sendcmpct_message" { else => unreachable, } } + +test "ok_send_cmpctblock_message" { + const Transaction = @import("../../types/transaction.zig"); + const OutPoint = @import("../../types/outpoint.zig"); + const OpCode = @import("../../script/opcodes/constant.zig").Opcode; + const Hash = @import("../../types/hash.zig"); + const Script = @import("../../types/script.zig"); + const CmpctBlockMessage = @import("../protocol/messages/cmpctblock.zig").CmpctBlockMessage; + + const allocator = std.testing.allocator; + + // Create a sample BlockHeader + const header = BlockHeader{ + .version = 1, + .prev_block = [_]u8{0} ** 32, // Zero-filled array of 32 bytes + .merkle_root = [_]u8{0} ** 32, // Zero-filled array of 32 bytes + .timestamp = 1631234567, + .nbits = 0x1d00ffff, + .nonce = 12345, + }; + + // Create sample short_ids + const short_ids = try allocator.alloc(u64, 2); + defer allocator.free(short_ids); + short_ids[0] = 123456789; + short_ids[1] = 987654321; + + // Create a sample Transaction + var tx = try Transaction.init(allocator); + defer tx.deinit(); + try tx.addInput(OutPoint{ .hash = Hash.newZeroed(), .index = 0 }); + { + var script_pubkey = try Script.init(allocator); + defer script_pubkey.deinit(); + try script_pubkey.push(&[_]u8{ OpCode.OP_DUP.toBytes(), OpCode.OP_HASH160.toBytes(), OpCode.OP_EQUALVERIFY.toBytes(), OpCode.OP_CHECKSIG.toBytes() }); + try tx.addOutput(50000, script_pubkey); + } + + // Create sample prefilled_txns + const prefilled_txns = try allocator.alloc(CmpctBlockMessage.PrefilledTransaction, 1); + defer allocator.free(prefilled_txns); + prefilled_txns[0] = .{ + .index = 0, + .tx = tx, + }; + + // Create CmpctBlockMessage + const msg = CmpctBlockMessage{ + .header = header, + .nonce = 9876543210, + .short_ids = short_ids, + .prefilled_txns = prefilled_txns, + }; + + // Test serialization + const serialized = try msg.serialize(allocator); + defer allocator.free(serialized); + + // Test deserialization + var deserialized = try CmpctBlockMessage.deserializeSlice(allocator, serialized); + defer deserialized.deinit(allocator); + + // Verify deserialized data + try std.testing.expect(msg.eql(&deserialized)); + + // Test hintSerializedLen + const hint_len = msg.hintSerializedLen(); + try std.testing.expect(hint_len > 0); + try std.testing.expect(hint_len == serialized.len); +}