Skip to content

PSBT setup #138

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
147 changes: 147 additions & 0 deletions bitcoinutils/psbt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
from bitcoinutils.transactions import (
Transaction,
)

from bitcoinutils.utils import (
encode_varint,
vi_to_int
)

MAGIC_BYTES = b"psbt\xff"
SEPARATOR = b'\x00'

# Global key types
PSBT_GLOBAL_UNSIGNED_TX = b'\x00'
PSBT_GLOBAL_XPUB = b'\x01'
PSBT_GLOBAL_TX_VERSION = b'\x02'
PSBT_GLOBAL_FALLBACK_LOCKTIME = b'\x03'
PSBT_GLOBAL_INPUT_COUNT = b'\x04'
PSBT_GLOBAL_OUTPUT_COUNT = b'\x05'
PSBT_GLOBAL_TX_MODIFIABLE = b'\x06'
PSBT_GLOBAL_SP_ECDH_SHARE = b'\x07'
PSBT_GLOBAL_SP_DLEQ = b'\x08'
PSBT_GLOBAL_VERSION = b'\xFB'
PSBT_GLOBAL_PROPRIETARY = b'\xFC'

# Per-input key types
PSBT_IN_NON_WITNESS_UTXO = b'\x00'
PSBT_IN_WITNESS_UTXO = b'\x01'
PSBT_IN_PARTIAL_SIG = b'\x02'
PSBT_IN_SIGHASH_TYPE = b'\x03'
PSBT_IN_REDEEM_SCRIPT = b'\x04'
PSBT_IN_WITNESS_SCRIPT = b'\x05'
PSBT_IN_BIP32_DERIVATION = b'\x06'
PSBT_IN_FINAL_SCRIPTSIG = b'\x07'
PSBT_IN_FINAL_SCRIPTWITNESS = b'\x08'

# Per-output key types
PSBT_OUT_REDEEM_SCRIPT = b'\x00'
PSBT_OUT_WITNESS_SCRIPT = b'\x01'
PSBT_OUT_BIP32_DERIVATION = b'\x02'
PSBT_OUT_AMOUNT = b'\x03'
PSBT_OUT_SCRIPT = b'\x04'


class PSBT:
def __init__(self, maps: dict):
'''
Parameters
----------
maps : dict
A dictionary with the keys 'global', 'input' and 'output' containing the corresponding maps.'''
self.maps = maps
#TODO: add checks to validate psbt (will be added in future PRs)

@staticmethod
def serialize_key_val(key: bytes, val: bytes):
'''Serialize a key value pair, key, val should be bytes'''
return encode_varint(len(key)) + key + encode_varint(len(val)) + val

@staticmethod
def parse_key_value(s):
"""Parse a key-value pair from the PSBT stream."""
# Read the first byte to determine the key length
key_length_bytes = s.read(1)
key_length, _ = vi_to_int(key_length_bytes)
# If key length is 0, return None (indicates a separator)
if key_length == 0:
return None, None
# Read the key
key = s.read(key_length)

# Read the value length
val_length_bytes = s.read(1)
val_length, _ = vi_to_int(val_length_bytes)
# Read the value
val = s.read(val_length)

return key, val

def serialize(self):
psbt = MAGIC_BYTES
# Here we are including keytype and keydata in key, therefore serialize_key_val() works as intended
for key, val in sorted(self.maps['global'].items()):
psbt += self.serialize_key_val(key, val)
psbt += SEPARATOR
for inp in self.maps['input']:
for key, val in sorted(inp.items()):
psbt += self.serialize_key_val(key, val)
psbt += SEPARATOR
for out in self.maps['output']:
for key, val in sorted(out.items()):
psbt += self.serialize_key_val(key, val)
psbt += SEPARATOR
return psbt


@classmethod
def parse(cls, s):
if s.read(5) != MAGIC_BYTES:
raise ValueError('Invalid PSBT magic bytes')
maps = {'global': {}, 'input': [], 'output': []}

globals = True #To check if paresed key value is from global map
input_ind = 0
output_ind = 0

while globals or input_ind > 0 or output_ind > 0:
key, val = PSBT.parse_key_value(s)

if globals:
if key is None: #Separator is reached indicating end of global map
globals = False
continue

maps['global'][key] = val


if key == PSBT_GLOBAL_UNSIGNED_TX: #If unsigned transaction is found, intialize input and output maps
hex_val = val.hex()
transaction = Transaction.from_raw(hex_val)
input_ind = len(transaction.inputs)
output_ind = len(transaction.outputs)
# input_ind = 1
# output_ind = 1
maps['input'] = [{} for _ in range(input_ind)]
maps['output'] = [{} for _ in range(output_ind)]

elif input_ind > 0: # Means input map is being parsed
if key is None: #Separator is reached; indicating end of the particular input map, there can be multiple input maps
input_ind -= 1
continue

ind = input_ind - len(maps['input']) #Get the index of the input being parsed
maps['input'][ind][key] = val

elif output_ind > 0: # Means output map is being parsed

if key is None: #Separator is reached; indicating end of the particular output map, there can be multiple output maps
output_ind -= 1
continue

ind = output_ind - len(maps['output']) #Get the index of the output being parsed
maps['output'][ind][key] = val

return cls(maps)

#TODO: Add methods to parse and serialize psbt as b64 and hex (will be added in future PRs)
85 changes: 85 additions & 0 deletions tests/test_psbt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import unittest
from io import BytesIO
from bitcoinutils.psbt import PSBT, MAGIC_BYTES, SEPARATOR, PSBT_GLOBAL_UNSIGNED_TX
from bitcoinutils.transactions import Transaction
from bitcoinutils.utils import encode_varint


class TestPSBT(unittest.TestCase):
def setUp(self):
# Sample PSBT maps for testing
self.sample_maps = {
'global': {
PSBT_GLOBAL_UNSIGNED_TX: bytes.fromhex(
"0100000001c3b5b9b07ec40d9e3f5edfa7e4f10b23bc653e5b6a5a1c79d1f4d232b3c6a29d000000006a473044022067e502e82d02e7a1a3b504897dfec4ea8df71a3b77cfe1b9cbf7d3f16a63642e02206e3b32b1e6b8f184a654bd22c6cb4a616274e0e44ed14e7f3e54d5e2d840cc6f012102a84c91d495bfecb17ea00e1dd6c634755643b95a09856c7cde4575a11b3a48e6ffffffff01a0860100000000001976a91489abcdefabbaabbaabbaabbaabbaabbaabbaabba88ac00000000"
),
},
'input': [
{b'\x00': b'\x01\x02\x03'}, # Example input map
],
'output': [
{b'\x00': b'\x04\x05\x06'}, # Example output map
]
}
self.psbt = PSBT(self.sample_maps)

def test_serialize(self):
"""Test if the PSBT object serializes correctly."""
serialized = self.psbt.serialize()

# Check if the serialized PSBT starts with the magic bytes
self.assertTrue(serialized.startswith(MAGIC_BYTES))

# Check if the global map is serialized correctly
for key, val in self.sample_maps['global'].items():
encoded_key_val = encode_varint(len(key)) + key + encode_varint(len(val)) + val
self.assertIn(encoded_key_val, serialized)

# Check if the input maps are serialized correctly
for inp in self.sample_maps['input']:
for key, val in inp.items():
encoded_key_val = encode_varint(len(key)) + key + encode_varint(len(val)) + val
self.assertIn(encoded_key_val, serialized)

# Check if the output maps are serialized correctly
for out in self.sample_maps['output']:
for key, val in out.items():
encoded_key_val = encode_varint(len(key)) + key + encode_varint(len(val)) + val
self.assertIn(encoded_key_val, serialized)


def test_parse(self):
"""Test if the PSBT object parses correctly."""
serialized = self.psbt.serialize()
parsed_psbt = PSBT.parse(BytesIO(serialized))

# Check if the parsed PSBT matches the original maps
self.assertEqual(parsed_psbt.maps['global'], self.sample_maps['global'])
self.assertEqual(parsed_psbt.maps['input'], self.sample_maps['input'])
self.assertEqual(parsed_psbt.maps['output'], self.sample_maps['output'])

def test_serialize_and_parse(self):
"""Test if serialization and parsing are consistent."""
serialized = self.psbt.serialize()
parsed_psbt = PSBT.parse(BytesIO(serialized))

# Serialize the parsed PSBT and compare with the original serialization
reserialized = parsed_psbt.serialize()
self.assertEqual(serialized, reserialized)

def test_parse_invalid_magic_bytes(self):
"""Test parsing with invalid magic bytes."""
invalid_psbt = b"abcd" + self.psbt.serialize()[4:] # Replace magic bytes
with self.assertRaises(ValueError) as context:
PSBT.parse(BytesIO(invalid_psbt))
self.assertEqual(str(context.exception), "Invalid PSBT magic bytes")

def test_parse_missing_separator(self):
"""Test parsing with missing separators."""
serialized = self.psbt.serialize().replace(SEPARATOR, b"") # Remove separators
with self.assertRaises(Exception): # Replace with a specific exception if implemented
PSBT.parse(BytesIO(serialized))


if __name__ == '__main__':
unittest.main()