diff --git a/libagent/device/trezor.py b/libagent/device/trezor.py index 62f0bcd..5d30685 100644 --- a/libagent/device/trezor.py +++ b/libagent/device/trezor.py @@ -3,11 +3,9 @@ import binascii import logging -import mnemonic import semver from . import interface -from .. import util log = logging.getLogger(__name__) @@ -28,66 +26,8 @@ class Trezor(interface.Device): required_version = '>=1.4.0' ui = None # can be overridden by device's users - - def _override_pin_handler(self, conn): - if self.ui is None: - return - - def new_handler(_): - try: - scrambled_pin = self.ui.get_pin() - result = self._defs.PinMatrixAck(pin=scrambled_pin) - if not set(scrambled_pin).issubset('123456789'): - raise self._defs.PinException( - None, 'Invalid scrambled PIN: {!r}'.format(result.pin)) - return result - except: # noqa - conn.init_device() - raise - - conn.callback_PinMatrixRequest = new_handler - - cached_passphrase_ack = util.ExpiringCache(seconds=float('inf')) cached_state = None - def _override_passphrase_handler(self, conn): - if self.ui is None: - return - - def new_handler(msg): - try: - if msg.on_device is True: - return self._defs.PassphraseAck() - ack = self.__class__.cached_passphrase_ack.get() - if ack: - log.debug('re-using cached %s passphrase', self) - return ack - - passphrase = self.ui.get_passphrase() - passphrase = mnemonic.Mnemonic.normalize_string(passphrase) - ack = self._defs.PassphraseAck(passphrase=passphrase) - - length = len(ack.passphrase) - if length > 50: - msg = 'Too long passphrase ({} chars)'.format(length) - raise ValueError(msg) - - self.__class__.cached_passphrase_ack.set(ack) - return ack - except: # noqa - conn.init_device() - raise - - conn.callback_PassphraseRequest = new_handler - - def _override_state_handler(self, conn): - def callback_PassphraseStateRequest(msg): - log.debug('caching state from %r', msg) - self.__class__.cached_state = msg.state - return self._defs.PassphraseStateAck() - - conn.callback_PassphraseStateRequest = callback_PassphraseStateRequest - def _verify_version(self, connection): f = connection.features log.debug('connected to %s %s', self, f.device_id) @@ -113,10 +53,8 @@ class Trezor(interface.Device): log.debug('using transport: %s', transport) for _ in range(5): # Retry a few times in case of PIN failures connection = self._defs.Client(transport=transport, + ui=self.ui, state=self.__class__.cached_state) - self._override_pin_handler(connection) - self._override_passphrase_handler(connection) - self._override_state_handler(connection) self._verify_version(connection) try: @@ -132,7 +70,8 @@ class Trezor(interface.Device): def close(self): """Close connection.""" - self.conn.close() + self.__class__.cached_state = self.conn.state + super().close() def pubkey(self, identity, ecdh=False): """Return public key.""" @@ -140,8 +79,10 @@ class Trezor(interface.Device): log.debug('"%s" getting public key (%s) from %s', identity.to_string(), curve_name, self) addr = identity.get_bip32_address(ecdh=ecdh) - result = self.conn.get_public_node( - n=addr, ecdsa_curve_name=curve_name) + result = self._defs.get_public_node( + self.conn, + n=addr, + ecdsa_curve_name=curve_name) log.debug('result: %s', result) return bytes(result.node.public_key) @@ -157,7 +98,8 @@ class Trezor(interface.Device): log.debug('"%s" signing %r (%s) on %s', identity.to_string(), blob, curve_name, self) try: - result = self.conn.sign_identity( + result = self._defs.sign_identity( + self.conn, identity=self._identity_proto(identity), challenge_hidden=blob, challenge_visual='', @@ -166,7 +108,7 @@ class Trezor(interface.Device): assert len(result.signature) == 65 assert result.signature[:1] == b'\x00' return bytes(result.signature[1:]) - except self._defs.CallException as e: + except self._defs.TrezorFailure as e: msg = '{} error: {}'.format(self, e) log.debug(msg, exc_info=True) raise interface.DeviceError(msg) @@ -177,7 +119,8 @@ class Trezor(interface.Device): log.debug('"%s" shared session key (%s) for %r from %s', identity.to_string(), curve_name, pubkey, self) try: - result = self.conn.get_ecdh_session_key( + result = self._defs.get_ecdh_session_key( + self.conn, identity=self._identity_proto(identity), peer_public_key=pubkey, ecdsa_curve_name=curve_name) @@ -185,7 +128,7 @@ class Trezor(interface.Device): assert len(result.session_key) in {65, 33} # NIST256 or Curve25519 assert result.session_key[:1] == b'\x04' return bytes(result.session_key) - except self._defs.CallException as e: + except self._defs.TrezorFailure as e: msg = '{} error: {}'.format(self, e) log.debug(msg, exc_info=True) raise interface.DeviceError(msg) diff --git a/libagent/device/trezor_defs.py b/libagent/device/trezor_defs.py index 82f5be8..b295637 100644 --- a/libagent/device/trezor_defs.py +++ b/libagent/device/trezor_defs.py @@ -4,19 +4,72 @@ import os import logging -from trezorlib.client import CallException, PinException -from trezorlib.client import TrezorClient as Client -from trezorlib.messages import IdentityType, PassphraseAck, PinMatrixAck, PassphraseStateAck - -try: - from trezorlib.transport import get_transport -except ImportError: - from trezorlib.device import TrezorDevice - get_transport = TrezorDevice.find_by_path +import mnemonic +import semver +import trezorlib log = logging.getLogger(__name__) +if semver.match(trezorlib.__version__, ">=0.11.0"): + from trezorlib.client import TrezorClient as Client + from trezorlib.exceptions import TrezorFailure, PinException + from trezorlib.transport import get_transport + from trezorlib.messages import IdentityType + + from trezorlib.btc import get_public_node + from trezorlib.misc import sign_identity, get_ecdh_session_key + +else: + from trezorlib.client import (TrezorClient, CallException as TrezorFailure, + PinException) + from trezorlib.messages import IdentityType + from trezorlib import messages + from trezorlib.transport import get_transport + + get_public_node = TrezorClient.get_public_node + sign_identity = TrezorClient.sign_identity + get_ecdh_session_key = TrezorClient.get_ecdh_session_key + + class Client(TrezorClient): + def __init__(self, transport, ui, state=None): + super().__init__(transport, state=state) + self.ui = ui + + def callback_PinMatrixRequest(self, msg): + try: + pin = self.ui.get_pin(msg.type) + if not pin.isdigit(): + raise PinException( + None, 'Invalid scrambled PIN: {!r}'.format(pin)) + return messages.PinMatrixAck(pin=pin) + except: # noqa + self.init_device() + raise + + def callback_PassphraseRequest(self, msg): + try: + if msg.on_device is True: + return messages.PassphraseAck() + + passphrase = self.ui.get_passphrase() + passphrase = mnemonic.Mnemonic.normalize_string(passphrase) + + length = len(passphrase) + if length > 50: + msg = 'Too long passphrase ({} chars)'.format(length) + raise ValueError(msg) + + return messages.PassphraseAck(passphrase=passphrase) + except: # noqa + self.init_device() + raise + + def callback_PassphraseStateRequest(self, msg): + self.state = msg.state + return messages.PassphraseStateAck() + + def find_device(): """Selects a transport based on `TREZOR_PATH` environment variable. diff --git a/libagent/device/ui.py b/libagent/device/ui.py index b99d03b..61a4f6b 100644 --- a/libagent/device/ui.py +++ b/libagent/device/ui.py @@ -24,7 +24,7 @@ class UI: self.options_getter = create_default_options_getter() self.device_name = device_type.__name__ - def get_pin(self, name=None): + def get_pin(self, _code=None): """Ask the user for (scrambled) PIN.""" description = ( 'Use the numeric keypad to describe number positions.\n' @@ -33,21 +33,25 @@ class UI: ' 4 5 6\n' ' 1 2 3') return interact( - title='{} PIN'.format(name or self.device_name), + title='{} PIN'.format(self.device_name), prompt='PIN:', description=description, binary=self.pin_entry_binary, options=self.options_getter()) - def get_passphrase(self, name=None): + def get_passphrase(self): """Ask the user for passphrase.""" return interact( - title='{} passphrase'.format(name or self.device_name), + title='{} passphrase'.format(self.device_name), prompt='Passphrase:', description=None, binary=self.passphrase_entry_binary, options=self.options_getter()) + def button_request(self, _code=None): + # XXX: show notification to the user? + pass + def create_default_options_getter(): """Return current TTY and DISPLAY settings for GnuPG pinentry."""