trezor: update tests and remove identity issues
This commit is contained in:
3
setup.py
3
setup.py
@@ -24,7 +24,6 @@ setup(
|
|||||||
'Topic :: Communications',
|
'Topic :: Communications',
|
||||||
],
|
],
|
||||||
entry_points={'console_scripts': [
|
entry_points={'console_scripts': [
|
||||||
'trezor-agent = sshagent.__main__:trezor_agent',
|
'trezor-agent = sshagent.__main__:trezor_agent'
|
||||||
'trezor-verify = sshagent.__main__:trezor_verify'
|
|
||||||
]},
|
]},
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -114,18 +114,3 @@ def trezor_agent():
|
|||||||
use_shell=use_shell)
|
use_shell=use_shell)
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
log.info('server stopped')
|
log.info('server stopped')
|
||||||
|
|
||||||
|
|
||||||
def trezor_verify():
|
|
||||||
|
|
||||||
p = argparse.ArgumentParser()
|
|
||||||
p.add_argument('-v', '--verbose', default=0, action='count')
|
|
||||||
p.add_argument('address', nargs='?', default=None)
|
|
||||||
args = p.parse_args()
|
|
||||||
|
|
||||||
setup_logging(verbosity=args.verbose)
|
|
||||||
host = subprocess.check_output('hostname')
|
|
||||||
label = '{}'.format(host)
|
|
||||||
with trezor.Client() as client:
|
|
||||||
return client.sign_identity(label=label,
|
|
||||||
expected_address=args.address)
|
|
||||||
|
|||||||
@@ -33,10 +33,10 @@ class ConnectionMock(object):
|
|||||||
def clear_session(self):
|
def clear_session(self):
|
||||||
self.closed = True
|
self.closed = True
|
||||||
|
|
||||||
def get_public_node(self, n, ecdsa_curve_name):
|
def get_public_node(self, n, ecdsa_curve_name='secp256k1'):
|
||||||
assert not self.closed
|
assert not self.closed
|
||||||
assert n == ADDR
|
assert n == ADDR
|
||||||
assert ecdsa_curve_name == CURVE
|
assert ecdsa_curve_name in {'secp256k1', 'nist256p1'}
|
||||||
result = mock.Mock(spec=[])
|
result = mock.Mock(spec=[])
|
||||||
result.node = mock.Mock(spec=[])
|
result.node = mock.Mock(spec=[])
|
||||||
result.node.public_key = PUBKEY
|
result.node.public_key = PUBKEY
|
||||||
@@ -74,7 +74,7 @@ SIG = (b'\x00R\x19T\xf2\x84$\xef#\x0e\xee\x04X\xc6\xc3\x99T`\xd1\xd8\xf7!'
|
|||||||
b'\xdc\xf0H\xab\xa8\xac\xa7? \x8f=C\x88N\xe2')
|
b'\xdc\xf0H\xab\xa8\xac\xa7? \x8f=C\x88N\xe2')
|
||||||
|
|
||||||
|
|
||||||
def test_client():
|
def test_ssh_agent():
|
||||||
c = client.Client(factory=FactoryMock)
|
c = client.Client(factory=FactoryMock)
|
||||||
ident = c.get_identity(label='localhost:22', protocol='ssh')
|
ident = c.get_identity(label='localhost:22', protocol='ssh')
|
||||||
assert ident.host == 'localhost'
|
assert ident.host == 'localhost'
|
||||||
@@ -86,8 +86,8 @@ def test_client():
|
|||||||
with c:
|
with c:
|
||||||
assert c.get_public_key(ident) == PUBKEY_TEXT
|
assert c.get_public_key(ident) == PUBKEY_TEXT
|
||||||
|
|
||||||
def _sign_identity(identity, challenge_hidden,
|
def ssh_sign_identity(identity, challenge_hidden,
|
||||||
challenge_visual, ecdsa_curve_name):
|
challenge_visual, ecdsa_curve_name):
|
||||||
assert identity is ident
|
assert identity is ident
|
||||||
assert challenge_hidden == BLOB
|
assert challenge_hidden == BLOB
|
||||||
assert challenge_visual == identity.path
|
assert challenge_visual == identity.path
|
||||||
@@ -98,9 +98,21 @@ def test_client():
|
|||||||
result.signature = SIG
|
result.signature = SIG
|
||||||
return result
|
return result
|
||||||
|
|
||||||
c.client.sign_identity = _sign_identity
|
c.client.sign_identity = ssh_sign_identity
|
||||||
signature = c.sign_ssh_challenge(identity=ident, blob=BLOB)
|
signature = c.sign_ssh_challenge(identity=ident, blob=BLOB)
|
||||||
|
|
||||||
key = formats.import_public_key(PUBKEY_TEXT)
|
key = formats.import_public_key(PUBKEY_TEXT)
|
||||||
assert key['verifying_key'].verify(signature=signature, data=BLOB,
|
assert key['verifying_key'].verify(signature=signature, data=BLOB,
|
||||||
sigdecode=lambda sig, _: sig)
|
sigdecode=lambda sig, _: sig)
|
||||||
|
|
||||||
|
|
||||||
|
def test_utils():
|
||||||
|
identity = mock.Mock(spec=[])
|
||||||
|
identity.proto = 'https'
|
||||||
|
identity.user = 'user'
|
||||||
|
identity.host = 'host'
|
||||||
|
identity.port = '443'
|
||||||
|
identity.path = '/path'
|
||||||
|
|
||||||
|
url = 'https://user@host:443/path'
|
||||||
|
assert client.identity_to_string(identity) == url
|
||||||
|
|||||||
@@ -2,8 +2,6 @@ import io
|
|||||||
import re
|
import re
|
||||||
import struct
|
import struct
|
||||||
import binascii
|
import binascii
|
||||||
import time
|
|
||||||
import os
|
|
||||||
|
|
||||||
from .. import util
|
from .. import util
|
||||||
from .. import formats
|
from .. import formats
|
||||||
@@ -35,7 +33,7 @@ class Client(object):
|
|||||||
self.client.close()
|
self.client.close()
|
||||||
|
|
||||||
def get_identity(self, label, protocol=None):
|
def get_identity(self, label, protocol=None):
|
||||||
identity = _string_to_identity(label, self.factory.identity_type)
|
identity = string_to_identity(label, self.factory.identity_type)
|
||||||
if protocol is not None:
|
if protocol is not None:
|
||||||
identity.proto = protocol
|
identity.proto = protocol
|
||||||
|
|
||||||
@@ -43,7 +41,7 @@ class Client(object):
|
|||||||
|
|
||||||
def get_public_key(self, identity):
|
def get_public_key(self, identity):
|
||||||
assert identity.proto == 'ssh'
|
assert identity.proto == 'ssh'
|
||||||
label = _identity_to_string(identity)
|
label = identity_to_string(identity)
|
||||||
log.info('getting "%s" public key from Trezor...', label)
|
log.info('getting "%s" public key from Trezor...', label)
|
||||||
addr = _get_address(identity)
|
addr = _get_address(identity)
|
||||||
node = self.client.get_public_node(n=addr,
|
node = self.client.get_public_node(n=addr,
|
||||||
@@ -54,7 +52,7 @@ class Client(object):
|
|||||||
|
|
||||||
def sign_ssh_challenge(self, identity, blob):
|
def sign_ssh_challenge(self, identity, blob):
|
||||||
assert identity.proto == 'ssh'
|
assert identity.proto == 'ssh'
|
||||||
label = _identity_to_string(identity)
|
label = identity_to_string(identity)
|
||||||
msg = _parse_ssh_blob(blob)
|
msg = _parse_ssh_blob(blob)
|
||||||
|
|
||||||
log.info('please confirm user "%s" login to "%s" using Trezor...',
|
log.info('please confirm user "%s" login to "%s" using Trezor...',
|
||||||
@@ -73,57 +71,6 @@ class Client(object):
|
|||||||
|
|
||||||
return parse_signature(result.signature)
|
return parse_signature(result.signature)
|
||||||
|
|
||||||
def sign_identity(self, label, expected_address=None,
|
|
||||||
_strftime=time.strftime, _urandom=os.urandom):
|
|
||||||
from bitcoin import pubkey_to_address
|
|
||||||
|
|
||||||
visual = _strftime('%d/%m/%y %H:%M:%S')
|
|
||||||
hidden = _urandom(64)
|
|
||||||
identity = self.get_identity(label=label)
|
|
||||||
|
|
||||||
derivation_path = _get_address(identity)
|
|
||||||
node = self.client.get_public_node(derivation_path)
|
|
||||||
address = pubkey_to_address(node.node.public_key)
|
|
||||||
log.info('address: %s', address)
|
|
||||||
|
|
||||||
if expected_address is None:
|
|
||||||
log.warning('Specify Bitcoin address: %s', address)
|
|
||||||
self.client.get_address(n=derivation_path,
|
|
||||||
coin_name='Bitcoin',
|
|
||||||
show_display=True)
|
|
||||||
return 2
|
|
||||||
|
|
||||||
assert expected_address == address
|
|
||||||
|
|
||||||
result = self.client.sign_identity(identity=identity,
|
|
||||||
challenge_hidden=hidden,
|
|
||||||
challenge_visual=visual)
|
|
||||||
|
|
||||||
assert address == result.address
|
|
||||||
assert node.node.public_key == result.public_key
|
|
||||||
|
|
||||||
digest = message_digest(hidden=hidden, visual=visual)
|
|
||||||
return _validate_signature(result=result, digest=digest)
|
|
||||||
|
|
||||||
|
|
||||||
def _validate_signature(result, digest, curve=formats.ecdsa.SECP256k1):
|
|
||||||
verifying_key = formats.decompress_pubkey(result.public_key,
|
|
||||||
curve=curve)
|
|
||||||
|
|
||||||
log.debug('digest: %s', binascii.hexlify(digest))
|
|
||||||
signature = parse_signature(result.signature)
|
|
||||||
log.debug('signature: %s', signature)
|
|
||||||
try:
|
|
||||||
verifying_key.verify_digest(signature=signature,
|
|
||||||
digest=digest,
|
|
||||||
sigdecode=lambda sig, _: sig)
|
|
||||||
except formats.ecdsa.BadSignatureError:
|
|
||||||
log.error('signature: ERROR')
|
|
||||||
return 1
|
|
||||||
|
|
||||||
log.info('signature: OK')
|
|
||||||
return 0
|
|
||||||
|
|
||||||
|
|
||||||
def parse_signature(blob):
|
def parse_signature(blob):
|
||||||
sig = blob[1:]
|
sig = blob[1:]
|
||||||
@@ -132,13 +79,6 @@ def parse_signature(blob):
|
|||||||
return (r, s)
|
return (r, s)
|
||||||
|
|
||||||
|
|
||||||
def message_digest(hidden, visual):
|
|
||||||
from bitcoin import electrum_sig_hash
|
|
||||||
hidden_digest = formats.hashfunc(hidden).digest()
|
|
||||||
visual_digest = formats.hashfunc(visual).digest()
|
|
||||||
return electrum_sig_hash(hidden_digest + visual_digest)
|
|
||||||
|
|
||||||
|
|
||||||
_identity_regexp = re.compile(''.join([
|
_identity_regexp = re.compile(''.join([
|
||||||
'^'
|
'^'
|
||||||
r'(?:(?P<proto>.*)://)?',
|
r'(?:(?P<proto>.*)://)?',
|
||||||
@@ -150,7 +90,7 @@ _identity_regexp = re.compile(''.join([
|
|||||||
]))
|
]))
|
||||||
|
|
||||||
|
|
||||||
def _string_to_identity(s, identity_type):
|
def string_to_identity(s, identity_type):
|
||||||
m = _identity_regexp.match(s)
|
m = _identity_regexp.match(s)
|
||||||
result = m.groupdict()
|
result = m.groupdict()
|
||||||
log.debug('parsed identity: %s', result)
|
log.debug('parsed identity: %s', result)
|
||||||
@@ -158,7 +98,7 @@ def _string_to_identity(s, identity_type):
|
|||||||
return identity_type(**kwargs)
|
return identity_type(**kwargs)
|
||||||
|
|
||||||
|
|
||||||
def _identity_to_string(identity):
|
def identity_to_string(identity):
|
||||||
result = []
|
result = []
|
||||||
if identity.proto:
|
if identity.proto:
|
||||||
result.append(identity.proto + '://')
|
result.append(identity.proto + '://')
|
||||||
@@ -174,7 +114,7 @@ def _identity_to_string(identity):
|
|||||||
|
|
||||||
def _get_address(identity):
|
def _get_address(identity):
|
||||||
index = struct.pack('<L', identity.index)
|
index = struct.pack('<L', identity.index)
|
||||||
addr = index + _identity_to_string(identity).encode('ascii')
|
addr = index + identity_to_string(identity).encode('ascii')
|
||||||
log.debug('address string: %r', addr)
|
log.debug('address string: %r', addr)
|
||||||
digest = formats.hashfunc(addr).digest()
|
digest = formats.hashfunc(addr).digest()
|
||||||
s = io.BytesIO(bytearray(digest))
|
s = io.BytesIO(bytearray(digest))
|
||||||
|
|||||||
Reference in New Issue
Block a user