From f22c07e97059d7486a501da0206c438680a84043 Mon Sep 17 00:00:00 2001 From: Roman Zeyde Date: Sat, 18 Nov 2017 20:36:46 +0200 Subject: [PATCH] trezor: retry in case of invalid PIN --- libagent/device/trezor.py | 64 ++++++++++++++++++++++----------------- 1 file changed, 37 insertions(+), 27 deletions(-) diff --git a/libagent/device/trezor.py b/libagent/device/trezor.py index 4d51eda..0e67ea8 100644 --- a/libagent/device/trezor.py +++ b/libagent/device/trezor.py @@ -62,7 +62,8 @@ class Trezor(interface.Device): 'Please enter PIN:') result = self._defs.PinMatrixAck(pin=scrambled_pin) if not set(result.pin).issubset('123456789'): - raise ValueError('Invalid scrambled PIN: {}'.format(result)) + raise self._defs.PinException( + None, 'Invalid scrambled PIN: {!r}'.format(result.pin)) return result conn.callback_PinMatrixRequest = new_handler @@ -79,35 +80,44 @@ class Trezor(interface.Device): conn.callback_PassphraseRequest = new_handler + def _verify_version(self, connection): + f = connection.features + log.debug('connected to %s %s', self, f.device_id) + log.debug('label : %s', f.label) + log.debug('vendor : %s', f.vendor) + current_version = '{}.{}.{}'.format(f.major_version, + f.minor_version, + f.patch_version) + log.debug('version : %s', current_version) + log.debug('revision : %s', binascii.hexlify(f.revision)) + if not semver.match(current_version, self.required_version): + fmt = ('Please upgrade your {} firmware to {} version' + ' (current: {})') + raise ValueError(fmt.format(self, self.required_version, + current_version)) + def connect(self): """Enumerate and connect to the first USB HID interface.""" for d in self._defs.Transport.enumerate(): log.debug('endpoint: %s', d) transport = self._defs.Transport(d) - connection = self._defs.Client(transport) - self._override_pin_handler(connection) - self._override_passphrase_handler(connection) - f = connection.features - log.debug('connected to %s %s', self, f.device_id) - log.debug('label : %s', f.label) - log.debug('vendor : %s', f.vendor) - current_version = '{}.{}.{}'.format(f.major_version, - f.minor_version, - f.patch_version) - log.debug('version : %s', current_version) - log.debug('revision : %s', binascii.hexlify(f.revision)) - if not semver.match(current_version, self.required_version): - fmt = ('Please upgrade your {} firmware to {} version' - ' (current: {})') - raise ValueError(fmt.format(self, self.required_version, - current_version)) - try: - connection.ping(msg='', pin_protection=True) # unlock PIN - except Exception as e: - log.exception('ping failed: %s', e) - connection.close() # so the next HID open() will succeed - raise - return connection + for _ in range(5): + connection = self._defs.Client(transport) + self._override_pin_handler(connection) + self._override_passphrase_handler(connection) + self._verify_version(connection) + + try: + connection.ping(msg='', pin_protection=True) # unlock PIN + return connection + except (self._defs.PinException, ValueError) as e: + log.error('Invalid PIN: %s, retrying...', e) + continue + except Exception as e: + log.exception('ping failed: %s', e) + connection.close() # so the next HID open() will succeed + raise + raise interface.NotFoundError('{} not connected'.format(self)) def close(self): @@ -120,8 +130,8 @@ class Trezor(interface.Device): log.debug('"%s" getting public key (%s) from %s', identity.to_string(), curve_name, self) addr = identity.get_bip32_address(ecdh=ecdh) - result = self.conn.get_public_node(n=addr, - ecdsa_curve_name=curve_name) + result = self.conn.get_public_node( + n=addr, ecdsa_curve_name=curve_name) log.debug('result: %s', result) return result.node.public_key