Merge Trezor and KeepKey functionality
This commit is contained in:
@@ -5,7 +5,7 @@ python:
|
||||
- "3.4"
|
||||
|
||||
install:
|
||||
- pip install ecdsa ed25519 # test without trezorlib for now
|
||||
- pip install ecdsa ed25519 semver # test without trezorlib for now
|
||||
- pip install pylint coverage pep8
|
||||
|
||||
script:
|
||||
|
||||
2
setup.py
2
setup.py
@@ -9,7 +9,7 @@ setup(
|
||||
author_email='roman.zeyde@gmail.com',
|
||||
url='http://github.com/romanz/trezor-agent',
|
||||
packages=['trezor_agent', 'trezor_agent.trezor'],
|
||||
install_requires=['ecdsa>=0.13', 'ed25519>=1.4', 'Cython>=0.23.4', 'trezor>=0.6.6', 'keepkey>=0.7.0'],
|
||||
install_requires=['ecdsa>=0.13', 'ed25519>=1.4', 'Cython>=0.23.4', 'trezor>=0.6.6', 'keepkey>=0.7.0', 'semver>=2.2'],
|
||||
platforms=['POSIX'],
|
||||
classifiers=[
|
||||
'Development Status :: 4 - Beta',
|
||||
|
||||
1
tox.ini
1
tox.ini
@@ -7,6 +7,7 @@ deps=
|
||||
pep8
|
||||
coverage
|
||||
pylint
|
||||
semver
|
||||
commands=
|
||||
pep8 trezor_agent
|
||||
pylint --reports=no --rcfile .pylintrc trezor_agent
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
import io
|
||||
|
||||
import mock
|
||||
import pytest
|
||||
|
||||
from .. import formats, util
|
||||
from ..trezor import client
|
||||
from ..trezor import client, factory
|
||||
|
||||
ADDR = [2147483661, 2810943954, 3938368396, 3454558782, 3848009040]
|
||||
CURVE = 'nist256p1'
|
||||
@@ -18,15 +17,7 @@ PUBKEY_TEXT = ('ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzd'
|
||||
|
||||
class ConnectionMock(object):
|
||||
|
||||
def __init__(self, version):
|
||||
self.features = mock.Mock(spec=[])
|
||||
self.features.device_id = '123456789'
|
||||
self.features.label = 'mywallet'
|
||||
self.features.vendor = 'mock'
|
||||
self.features.major_version = version[0]
|
||||
self.features.minor_version = version[1]
|
||||
self.features.patch_version = version[2]
|
||||
self.features.revision = b'456'
|
||||
def __init__(self):
|
||||
self.closed = False
|
||||
|
||||
def close(self):
|
||||
@@ -49,21 +40,20 @@ class ConnectionMock(object):
|
||||
return msg
|
||||
|
||||
|
||||
class FactoryMock(object):
|
||||
def identity_type(**kwargs):
|
||||
result = mock.Mock(spec=[])
|
||||
result.index = 0
|
||||
result.proto = result.user = result.host = result.port = None
|
||||
result.path = None
|
||||
for k, v in kwargs.items():
|
||||
setattr(result, k, v)
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def client():
|
||||
return ConnectionMock(version=(1, 3, 4))
|
||||
|
||||
@staticmethod
|
||||
def identity_type(**kwargs):
|
||||
result = mock.Mock(spec=[])
|
||||
result.index = 0
|
||||
result.proto = result.user = result.host = result.port = None
|
||||
result.path = None
|
||||
for k, v in kwargs.items():
|
||||
setattr(result, k, v)
|
||||
return result
|
||||
def load_client():
|
||||
return factory.ClientWrapper(connection=ConnectionMock(),
|
||||
identity_type=identity_type,
|
||||
device_name='DEVICE_NAME')
|
||||
|
||||
|
||||
BLOB = (b'\x00\x00\x00 \xce\xe0\xc9\xd5\xceu/\xe8\xc5\xf2\xbfR+x\xa1\xcf\xb0'
|
||||
@@ -82,7 +72,7 @@ SIG = (b'\x00R\x19T\xf2\x84$\xef#\x0e\xee\x04X\xc6\xc3\x99T`\xd1\xd8\xf7!'
|
||||
|
||||
def test_ssh_agent():
|
||||
label = 'localhost:22'
|
||||
c = client.Client(factory=FactoryMock)
|
||||
c = client.Client(loader=load_client)
|
||||
ident = c.get_identity(label=label)
|
||||
assert ident.host == 'localhost'
|
||||
assert ident.proto == 'ssh'
|
||||
@@ -129,15 +119,3 @@ def test_utils():
|
||||
|
||||
url = 'https://user@host:443/path'
|
||||
assert client.identity_to_string(identity) == url
|
||||
|
||||
|
||||
def test_old_version():
|
||||
|
||||
class OldFactoryMock(FactoryMock):
|
||||
|
||||
@staticmethod
|
||||
def client():
|
||||
return ConnectionMock(version=(1, 2, 3))
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
client.Client(factory=OldFactoryMock)
|
||||
|
||||
@@ -1,37 +0,0 @@
|
||||
''' Thin wrapper around trezorlib. '''
|
||||
|
||||
|
||||
def client():
|
||||
# pylint: disable=import-error
|
||||
from trezorlib.client import TrezorClient
|
||||
from trezorlib.transport_hid import HidTransport as TrezorHidTransport
|
||||
from trezorlib.messages_pb2 import PassphraseAck as TrezorPassphraseAck
|
||||
|
||||
from keepkeylib.client import KeepKeyClient
|
||||
from keepkeylib.transport_hid import HidTransport as KeepKeyHidTransport
|
||||
from keepkeylib.messages_pb2 import PassphraseAck as KeepKeyPassphraseAck
|
||||
|
||||
devices = list(TrezorHidTransport.enumerate())
|
||||
if len(devices) == 1:
|
||||
t = TrezorClient(TrezorHidTransport(devices[0]))
|
||||
t.callback_PassphraseRequest = lambda msg: TrezorPassphraseAck(passphrase='')
|
||||
else:
|
||||
devices = list(KeepKeyHidTransport.enumerate())
|
||||
if len(devices) != 1:
|
||||
msg = '{:d} devices found'.format(len(devices))
|
||||
raise IOError(msg)
|
||||
t = KeepKeyClient(KeepKeyHidTransport(devices[0]))
|
||||
t.callback_PassphraseRequest = lambda msg: KeepKeyPassphraseAck(passphrase='')
|
||||
|
||||
return t
|
||||
|
||||
|
||||
def trezor_identity_type(**kwargs):
|
||||
# pylint: disable=import-error
|
||||
from trezorlib.types_pb2 import IdentityType
|
||||
return IdentityType(**kwargs)
|
||||
|
||||
def keepkey_identity_type(**kwargs):
|
||||
# pylint: disable=import-error
|
||||
from keepkeylib.types_pb2 import IdentityType
|
||||
return IdentityType(**kwargs)
|
||||
@@ -4,7 +4,7 @@ import logging
|
||||
import re
|
||||
import struct
|
||||
|
||||
from . import _factory as Factory
|
||||
from . import factory
|
||||
from .. import formats, util
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
@@ -12,29 +12,12 @@ log = logging.getLogger(__name__)
|
||||
|
||||
class Client(object):
|
||||
|
||||
TREZOR_MIN_VERSION = [1, 3, 4]
|
||||
KEEPKEY_MIN_VERSION = [1, 0, 4]
|
||||
|
||||
def __init__(self, factory=Factory, curve=formats.CURVE_NIST256):
|
||||
def __init__(self, loader=factory.load, curve=formats.CURVE_NIST256):
|
||||
client_wrapper = loader()
|
||||
self.client = client_wrapper.connection
|
||||
self.identity_type = client_wrapper.identity_type
|
||||
self.device_name = client_wrapper.device_name
|
||||
self.curve = curve
|
||||
self.factory = factory
|
||||
self.client = self.factory.client()
|
||||
f = self.client.features
|
||||
log.debug('connected to Trezor %s', f.device_id)
|
||||
log.debug('label : %s', f.label)
|
||||
log.debug('vendor : %s', f.vendor)
|
||||
version = [f.major_version, f.minor_version, f.patch_version]
|
||||
version_str = '.'.join([str(v) for v in version])
|
||||
log.debug('version : %s', version_str)
|
||||
log.debug('revision : %s', binascii.hexlify(f.revision))
|
||||
if f.vendor == 'bitcointrezor.com' and version < self.TREZOR_MIN_VERSION:
|
||||
fmt = 'Please upgrade your TREZOR to v{}+ firmware'
|
||||
version_str = '.'.join([str(v) for v in self.TREZOR_MIN_VERSION])
|
||||
raise ValueError(fmt.format(version_str))
|
||||
elif f.vendor == 'keepkey.com' and version < self.KEEPKEY_MIN_VERSION:
|
||||
fmt = 'Please upgrade your KEEPKEY to v{}+ firmware'
|
||||
version_str = '.'.join([str(v) for v in self.KEEPKEY_MIN_VERSION])
|
||||
raise ValueError(fmt.format(version_str))
|
||||
|
||||
def __enter__(self):
|
||||
msg = 'Hello World!'
|
||||
@@ -42,24 +25,20 @@ class Client(object):
|
||||
return self
|
||||
|
||||
def __exit__(self, *args):
|
||||
log.info('disconnected from Trezor')
|
||||
log.info('disconnected from %s', self.device_name)
|
||||
self.client.clear_session() # forget PIN and shutdown screen
|
||||
self.client.close()
|
||||
|
||||
def get_identity(self, label):
|
||||
identity = string_to_identity(label, self.factory.trezor_identity_type)
|
||||
|
||||
if self.client.features.vendor == 'keepkey.com':
|
||||
identity = string_to_identity(label, self.factory.keepkey_identity_type)
|
||||
|
||||
identity = string_to_identity(label, self.identity_type)
|
||||
identity.proto = 'ssh'
|
||||
return identity
|
||||
|
||||
def get_public_key(self, label):
|
||||
identity = self.get_identity(label=label)
|
||||
label = identity_to_string(identity) # canonize key label
|
||||
log.info('getting "%s" public key (%s) from Trezor...',
|
||||
label, self.curve)
|
||||
log.info('getting "%s" public key (%s) from %s...',
|
||||
label, self.curve, self.device_name)
|
||||
addr = _get_address(identity)
|
||||
node = self.client.get_public_node(n=addr,
|
||||
ecdsa_curve_name=self.curve)
|
||||
@@ -72,8 +51,8 @@ class Client(object):
|
||||
identity = self.get_identity(label=label)
|
||||
msg = _parse_ssh_blob(blob)
|
||||
|
||||
log.info('please confirm user "%s" login to "%s" using Trezor...',
|
||||
msg['user'], label)
|
||||
log.info('please confirm user "%s" login to "%s" using %s...',
|
||||
msg['user'], label, self.device_name)
|
||||
|
||||
visual = identity.path # not signed when proto='ssh'
|
||||
result = self.client.sign_identity(identity=identity,
|
||||
|
||||
78
trezor_agent/trezor/factory.py
Normal file
78
trezor_agent/trezor/factory.py
Normal file
@@ -0,0 +1,78 @@
|
||||
''' Thin wrapper around trezor/keepkey libraries. '''
|
||||
import binascii
|
||||
import collections
|
||||
import logging
|
||||
|
||||
import semver
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
ClientWrapper = collections.namedtuple(
|
||||
'ClientWrapper',
|
||||
['connection', 'identity_type', 'device_name'])
|
||||
|
||||
|
||||
# pylint: disable=too-many-arguments
|
||||
def _load_client(name, client_type, hid_transport,
|
||||
passphrase_ack, identity_type, required_version):
|
||||
|
||||
def empty_passphrase_handler(_):
|
||||
return passphrase_ack(passphrase='')
|
||||
|
||||
for d in hid_transport.enumerate():
|
||||
connection = client_type(hid_transport(d))
|
||||
connection.callback_PassphraseRequest = empty_passphrase_handler
|
||||
f = connection.features
|
||||
log.debug('connected to %s %s', name, f.device_id)
|
||||
log.debug('label : %s', f.label)
|
||||
log.debug('vendor : %s', f.vendor)
|
||||
current_version = '{}.{}.{}'.format(f.major_version,
|
||||
f.minor_version,
|
||||
f.patch_version)
|
||||
log.debug('version : %s', current_version)
|
||||
log.debug('revision : %s', binascii.hexlify(f.revision))
|
||||
if not semver.match(current_version, required_version):
|
||||
fmt = 'Please upgrade your {} firmware to {} version (current: {})'
|
||||
raise ValueError(fmt.format(name,
|
||||
required_version,
|
||||
current_version))
|
||||
yield ClientWrapper(connection=connection,
|
||||
identity_type=identity_type,
|
||||
device_name=name)
|
||||
|
||||
|
||||
def _load_trezor():
|
||||
# pylint: disable=import-error
|
||||
from trezorlib.client import TrezorClient
|
||||
from trezorlib.transport_hid import HidTransport
|
||||
from trezorlib.messages_pb2 import PassphraseAck
|
||||
from trezorlib.types_pb2 import IdentityType
|
||||
return _load_client(name='Trezor',
|
||||
client_type=TrezorClient,
|
||||
hid_transport=HidTransport,
|
||||
passphrase_ack=PassphraseAck,
|
||||
identity_type=IdentityType,
|
||||
required_version='>=1.3.4')
|
||||
|
||||
|
||||
def _load_keepkey():
|
||||
# pylint: disable=import-error
|
||||
from keepkeylib.client import KeepKeyClient
|
||||
from keepkeylib.transport_hid import HidTransport
|
||||
from keepkeylib.messages_pb2 import PassphraseAck
|
||||
from keepkeylib.types_pb2 import IdentityType
|
||||
return _load_client(name='KeepKey',
|
||||
client_type=KeepKeyClient,
|
||||
hid_transport=HidTransport,
|
||||
passphrase_ack=PassphraseAck,
|
||||
identity_type=IdentityType,
|
||||
required_version='>=1.0.4')
|
||||
|
||||
|
||||
def load():
|
||||
devices = list(_load_trezor()) + list(_load_keepkey())
|
||||
if len(devices) == 1:
|
||||
return devices[0]
|
||||
|
||||
msg = '{:d} devices found'.format(len(devices))
|
||||
raise IOError(msg)
|
||||
Reference in New Issue
Block a user