trezor: update tests and remove identity issues

This commit is contained in:
Roman Zeyde
2015-08-18 18:52:29 +03:00
parent bf6b58971a
commit 592bc78391
4 changed files with 25 additions and 89 deletions

View File

@@ -24,7 +24,6 @@ setup(
'Topic :: Communications', 'Topic :: Communications',
], ],
entry_points={'console_scripts': [ entry_points={'console_scripts': [
'trezor-agent = sshagent.__main__:trezor_agent', 'trezor-agent = sshagent.__main__:trezor_agent'
'trezor-verify = sshagent.__main__:trezor_verify'
]}, ]},
) )

View File

@@ -114,18 +114,3 @@ def trezor_agent():
use_shell=use_shell) use_shell=use_shell)
except KeyboardInterrupt: except KeyboardInterrupt:
log.info('server stopped') log.info('server stopped')
def trezor_verify():
p = argparse.ArgumentParser()
p.add_argument('-v', '--verbose', default=0, action='count')
p.add_argument('address', nargs='?', default=None)
args = p.parse_args()
setup_logging(verbosity=args.verbose)
host = subprocess.check_output('hostname')
label = '{}'.format(host)
with trezor.Client() as client:
return client.sign_identity(label=label,
expected_address=args.address)

View File

@@ -33,10 +33,10 @@ class ConnectionMock(object):
def clear_session(self): def clear_session(self):
self.closed = True self.closed = True
def get_public_node(self, n, ecdsa_curve_name): def get_public_node(self, n, ecdsa_curve_name='secp256k1'):
assert not self.closed assert not self.closed
assert n == ADDR assert n == ADDR
assert ecdsa_curve_name == CURVE assert ecdsa_curve_name in {'secp256k1', 'nist256p1'}
result = mock.Mock(spec=[]) result = mock.Mock(spec=[])
result.node = mock.Mock(spec=[]) result.node = mock.Mock(spec=[])
result.node.public_key = PUBKEY result.node.public_key = PUBKEY
@@ -74,7 +74,7 @@ SIG = (b'\x00R\x19T\xf2\x84$\xef#\x0e\xee\x04X\xc6\xc3\x99T`\xd1\xd8\xf7!'
b'\xdc\xf0H\xab\xa8\xac\xa7? \x8f=C\x88N\xe2') b'\xdc\xf0H\xab\xa8\xac\xa7? \x8f=C\x88N\xe2')
def test_client(): def test_ssh_agent():
c = client.Client(factory=FactoryMock) c = client.Client(factory=FactoryMock)
ident = c.get_identity(label='localhost:22', protocol='ssh') ident = c.get_identity(label='localhost:22', protocol='ssh')
assert ident.host == 'localhost' assert ident.host == 'localhost'
@@ -86,8 +86,8 @@ def test_client():
with c: with c:
assert c.get_public_key(ident) == PUBKEY_TEXT assert c.get_public_key(ident) == PUBKEY_TEXT
def _sign_identity(identity, challenge_hidden, def ssh_sign_identity(identity, challenge_hidden,
challenge_visual, ecdsa_curve_name): challenge_visual, ecdsa_curve_name):
assert identity is ident assert identity is ident
assert challenge_hidden == BLOB assert challenge_hidden == BLOB
assert challenge_visual == identity.path assert challenge_visual == identity.path
@@ -98,9 +98,21 @@ def test_client():
result.signature = SIG result.signature = SIG
return result return result
c.client.sign_identity = _sign_identity c.client.sign_identity = ssh_sign_identity
signature = c.sign_ssh_challenge(identity=ident, blob=BLOB) signature = c.sign_ssh_challenge(identity=ident, blob=BLOB)
key = formats.import_public_key(PUBKEY_TEXT) key = formats.import_public_key(PUBKEY_TEXT)
assert key['verifying_key'].verify(signature=signature, data=BLOB, assert key['verifying_key'].verify(signature=signature, data=BLOB,
sigdecode=lambda sig, _: sig) sigdecode=lambda sig, _: sig)
def test_utils():
identity = mock.Mock(spec=[])
identity.proto = 'https'
identity.user = 'user'
identity.host = 'host'
identity.port = '443'
identity.path = '/path'
url = 'https://user@host:443/path'
assert client.identity_to_string(identity) == url

View File

@@ -2,8 +2,6 @@ import io
import re import re
import struct import struct
import binascii import binascii
import time
import os
from .. import util from .. import util
from .. import formats from .. import formats
@@ -35,7 +33,7 @@ class Client(object):
self.client.close() self.client.close()
def get_identity(self, label, protocol=None): def get_identity(self, label, protocol=None):
identity = _string_to_identity(label, self.factory.identity_type) identity = string_to_identity(label, self.factory.identity_type)
if protocol is not None: if protocol is not None:
identity.proto = protocol identity.proto = protocol
@@ -43,7 +41,7 @@ class Client(object):
def get_public_key(self, identity): def get_public_key(self, identity):
assert identity.proto == 'ssh' assert identity.proto == 'ssh'
label = _identity_to_string(identity) label = identity_to_string(identity)
log.info('getting "%s" public key from Trezor...', label) log.info('getting "%s" public key from Trezor...', label)
addr = _get_address(identity) addr = _get_address(identity)
node = self.client.get_public_node(n=addr, node = self.client.get_public_node(n=addr,
@@ -54,7 +52,7 @@ class Client(object):
def sign_ssh_challenge(self, identity, blob): def sign_ssh_challenge(self, identity, blob):
assert identity.proto == 'ssh' assert identity.proto == 'ssh'
label = _identity_to_string(identity) label = identity_to_string(identity)
msg = _parse_ssh_blob(blob) msg = _parse_ssh_blob(blob)
log.info('please confirm user "%s" login to "%s" using Trezor...', log.info('please confirm user "%s" login to "%s" using Trezor...',
@@ -73,57 +71,6 @@ class Client(object):
return parse_signature(result.signature) return parse_signature(result.signature)
def sign_identity(self, label, expected_address=None,
_strftime=time.strftime, _urandom=os.urandom):
from bitcoin import pubkey_to_address
visual = _strftime('%d/%m/%y %H:%M:%S')
hidden = _urandom(64)
identity = self.get_identity(label=label)
derivation_path = _get_address(identity)
node = self.client.get_public_node(derivation_path)
address = pubkey_to_address(node.node.public_key)
log.info('address: %s', address)
if expected_address is None:
log.warning('Specify Bitcoin address: %s', address)
self.client.get_address(n=derivation_path,
coin_name='Bitcoin',
show_display=True)
return 2
assert expected_address == address
result = self.client.sign_identity(identity=identity,
challenge_hidden=hidden,
challenge_visual=visual)
assert address == result.address
assert node.node.public_key == result.public_key
digest = message_digest(hidden=hidden, visual=visual)
return _validate_signature(result=result, digest=digest)
def _validate_signature(result, digest, curve=formats.ecdsa.SECP256k1):
verifying_key = formats.decompress_pubkey(result.public_key,
curve=curve)
log.debug('digest: %s', binascii.hexlify(digest))
signature = parse_signature(result.signature)
log.debug('signature: %s', signature)
try:
verifying_key.verify_digest(signature=signature,
digest=digest,
sigdecode=lambda sig, _: sig)
except formats.ecdsa.BadSignatureError:
log.error('signature: ERROR')
return 1
log.info('signature: OK')
return 0
def parse_signature(blob): def parse_signature(blob):
sig = blob[1:] sig = blob[1:]
@@ -132,13 +79,6 @@ def parse_signature(blob):
return (r, s) return (r, s)
def message_digest(hidden, visual):
from bitcoin import electrum_sig_hash
hidden_digest = formats.hashfunc(hidden).digest()
visual_digest = formats.hashfunc(visual).digest()
return electrum_sig_hash(hidden_digest + visual_digest)
_identity_regexp = re.compile(''.join([ _identity_regexp = re.compile(''.join([
'^' '^'
r'(?:(?P<proto>.*)://)?', r'(?:(?P<proto>.*)://)?',
@@ -150,7 +90,7 @@ _identity_regexp = re.compile(''.join([
])) ]))
def _string_to_identity(s, identity_type): def string_to_identity(s, identity_type):
m = _identity_regexp.match(s) m = _identity_regexp.match(s)
result = m.groupdict() result = m.groupdict()
log.debug('parsed identity: %s', result) log.debug('parsed identity: %s', result)
@@ -158,7 +98,7 @@ def _string_to_identity(s, identity_type):
return identity_type(**kwargs) return identity_type(**kwargs)
def _identity_to_string(identity): def identity_to_string(identity):
result = [] result = []
if identity.proto: if identity.proto:
result.append(identity.proto + '://') result.append(identity.proto + '://')
@@ -174,7 +114,7 @@ def _identity_to_string(identity):
def _get_address(identity): def _get_address(identity):
index = struct.pack('<L', identity.index) index = struct.pack('<L', identity.index)
addr = index + _identity_to_string(identity).encode('ascii') addr = index + identity_to_string(identity).encode('ascii')
log.debug('address string: %r', addr) log.debug('address string: %r', addr)
digest = formats.hashfunc(addr).digest() digest = formats.hashfunc(addr).digest()
s = io.BytesIO(bytearray(digest)) s = io.BytesIO(bytearray(digest))