diff --git a/README.md b/README.md index af7ba6a..6bd309f 100644 --- a/README.md +++ b/README.md @@ -136,6 +136,10 @@ $ ledgerctl -v run Bitcoin <= 9000 ``` +### Using BLE + +BLE scanning is disabled by default. It can be activated by setting an environment variable named `LEDGER_USE_BLE`. + ## Contributing ### Pre-commit checks diff --git a/ledgerwallet/client.py b/ledgerwallet/client.py index 4106844..09fd810 100644 --- a/ledgerwallet/client.py +++ b/ledgerwallet/client.py @@ -164,12 +164,14 @@ class NoLedgerDeviceException(Exception): class LedgerClient(object): def __init__(self, device=None, cla=0xE0, private_key=None): + self.device = None if device is None: devices = enumerate_devices() if len(devices) == 0: raise NoLedgerDeviceException("No Ledger device has been found.") device = devices[0] self.device = device + LOG.debug(self.device) self.cla = cla self._target_id = None self.scp = None @@ -179,8 +181,12 @@ def __init__(self, device=None, cla=0xE0, private_key=None): self.private_key = PrivateKey(private_key) self.device.open() + def __del__(self): + self.close() + def close(self): - self.device.close() + if self.device is not None: + self.device.close() def raw_exchange(self, data: bytes) -> bytes: LOG.debug("=> " + data.hex()) diff --git a/ledgerwallet/transport/__init__.py b/ledgerwallet/transport/__init__.py index 452cdf2..e339bb8 100644 --- a/ledgerwallet/transport/__init__.py +++ b/ledgerwallet/transport/__init__.py @@ -1,10 +1,11 @@ from contextlib import contextmanager +from .ble import BleDevice from .device import Device from .hid import HidDevice from .tcp import TcpDevice -DEVICE_CLASSES = [TcpDevice, HidDevice] +DEVICE_CLASSES = [TcpDevice, HidDevice, BleDevice] def enumerate_devices(): diff --git a/ledgerwallet/transport/ble.py b/ledgerwallet/transport/ble.py new file mode 100644 index 0000000..1ce661e --- /dev/null +++ b/ledgerwallet/transport/ble.py @@ -0,0 +1,137 @@ +import asyncio +import os +from typing import List + +from bleak import BleakClient, BleakScanner +from bleak.exc import BleakError + +HANDLE_CHAR_ENABLE_NOTIF = 13 +HANDLE_CHAR_WRITE = 16 +TAG_ID = b"\x05" + + +queue: asyncio.Queue = asyncio.Queue() + + +async def ble_discover(): + devices = await BleakScanner.discover(timeout=1.0) + return devices + + +def callback(sender, data): + response = bytes(data) + queue.put_nowait(response) + + +async def _get_client(ble_address: str) -> BleakClient: + device = await BleakScanner.find_device_by_address(ble_address, timeout=1.0) + if not device: + raise BleakError(f"Device with address {ble_address} could not be found.") + + client = BleakClient(device) + await client.connect() + + # register notification callback + # callback = lambda sender, data: queue.put_nowait(bytes(data)) + await client.start_notify(HANDLE_CHAR_ENABLE_NOTIF, callback) + + # enable notifications + await client.write_gatt_char(HANDLE_CHAR_WRITE, bytes.fromhex("0001"), True) + assert await queue.get() == b"\x00\x00\x00\x00\x00" + + # confirm that the MTU is 0x99 + await client.write_gatt_char(HANDLE_CHAR_WRITE, bytes.fromhex("0800000000"), True) + assert await queue.get() == b"\x08\x00\x00\x00\x01\x99" + + return client + + +async def _read() -> bytes: + response = await queue.get() + + assert len(response) >= 5 + assert response[0] == TAG_ID[0] + assert response[1:3] == b"\x00\x00" + total_size = int.from_bytes(response[3:5], "big") + + apdu = response[5:] + i = 1 + if len(apdu) < total_size: + assert total_size > len(response) - 5 + + response = await queue.get() + + assert len(response) >= 3 + assert response[0] == TAG_ID[0] + assert int.from_bytes(response[1:3], "big") == i + i += 1 + apdu += response[3:] + + assert len(apdu) == total_size + return apdu + + +async def _write(client: BleakClient, data: bytes, mtu: int = 0x99): + chunks: List[bytes] = [] + buffer = data + while buffer: + if not chunks: + size = 5 + else: + size = 3 + size = mtu - size + chunks.append(buffer[:size]) + buffer = buffer[size:] + + for i, chunk in enumerate(chunks): + header = TAG_ID + header += i.to_bytes(2, "big") + if i == 0: + header += len(data).to_bytes(2, "big") + await client.write_gatt_char(HANDLE_CHAR_WRITE, header + chunk, True) + + +class BleDevice(object): + def __init__(self, device): + self.device = device + self.loop = None + self.client = None + self.opened = False + + @classmethod + def enumerate_devices(cls): + if "LEDGER_USE_BLE" in os.environ: + loop = asyncio.get_event_loop() + discovered_devices = loop.run_until_complete(ble_discover()) + devices = [] + for device in discovered_devices: + if device.name is not None: + if device.name.startswith("Nano X"): + devices.append(BleDevice(device)) + return devices + else: + return [] + + def __str__(self): + return "[BLE Device] {} ({})".format(self.device.name, self.device.address) + + def open(self): + self.loop = asyncio.get_event_loop() + self.client = self.loop.run_until_complete(_get_client(self.device.address)) + self.opened = True + + def close(self): + if self.opened: + self.loop.run_until_complete(self.client.disconnect()) + self.opened = False + self.loop.close() + + def write(self, data: bytes): + self.loop.run_until_complete(_write(self.client, data)) + + def read(self) -> bytes: + return self.loop.run_until_complete(_read()) + + def exchange(self, data: bytes, timeout=1000): + self.write(data) + return self.read() diff --git a/ledgerwallet/transport/hid.py b/ledgerwallet/transport/hid.py index e078c98..0428025 100644 --- a/ledgerwallet/transport/hid.py +++ b/ledgerwallet/transport/hid.py @@ -22,6 +22,9 @@ def enumerate_devices(cls): devices.append(HidDevice(hid_device_path)) return devices + def __str__(self): + return "[HID Device] {}".format(self.path.decode()) + def get_name(self): return "hid:{}".format(self.path.decode()) diff --git a/ledgerwallet/transport/tcp.py b/ledgerwallet/transport/tcp.py index ace9f93..1f4a19f 100644 --- a/ledgerwallet/transport/tcp.py +++ b/ledgerwallet/transport/tcp.py @@ -28,6 +28,9 @@ def enumerate_devices(cls): else: return [] + def __str__(self): + return "[TCP Device] {}:{}".format(self.server, self.port) + def open(self): self.socket.connect((self.server, self.port)) diff --git a/pyproject.toml b/pyproject.toml index 59459fa..f45ee3a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,6 +19,7 @@ classifiers = [ dynamic = ["version", "description"] requires-python = ">=3.7" dependencies = [ + "bleak", "click >=8.0", "construct >=2.10", "cryptography >=2.5",