ssh: allow "just-in-time" connection for agent-like behaviour
This would allow launching trezor-agent into the background during the system startup, and the connecting the device when the cryptographic operations are required.
This commit is contained in:
@@ -110,12 +110,10 @@ def git_host(remote_name, attributes):
|
||||
return '{user}@{host}'.format(**match.groupdict())
|
||||
|
||||
|
||||
def run_server(conn, public_keys, command, debug, timeout):
|
||||
def run_server(conn, command, debug, timeout):
|
||||
"""Common code for run_agent and run_git below."""
|
||||
try:
|
||||
signer = conn.sign_ssh_challenge
|
||||
handler = protocol.Handler(keys=public_keys, signer=signer,
|
||||
debug=debug)
|
||||
handler = protocol.Handler(conn=conn, debug=debug)
|
||||
with server.serve(handler=handler, timeout=timeout) as env:
|
||||
return server.run_process(command=command, environ=env)
|
||||
except KeyboardInterrupt:
|
||||
@@ -142,13 +140,39 @@ def parse_config(fname):
|
||||
curve_name=curve_name)
|
||||
|
||||
|
||||
class JustInTimeConnection(object):
|
||||
"""Connect to the device just before the needed operation."""
|
||||
|
||||
def __init__(self, conn_factory, identities):
|
||||
"""Create a JIT connection object."""
|
||||
self.conn_factory = conn_factory
|
||||
self.identities = identities
|
||||
|
||||
def public_keys(self):
|
||||
"""Return a list of SSH public keys (in textual format)."""
|
||||
conn = self.conn_factory()
|
||||
return [conn.get_public_key(i) for i in self.identities]
|
||||
|
||||
def parse_public_keys(self):
|
||||
"""Parse SSH public keys into dictionaries."""
|
||||
public_keys = [formats.import_public_key(pk)
|
||||
for pk in self.public_keys()]
|
||||
for pk, identity in zip(public_keys, self.identities):
|
||||
pk['identity'] = identity
|
||||
return public_keys
|
||||
|
||||
def sign(self, blob, identity):
|
||||
"""Sign a given blob using the specified identity on the device."""
|
||||
conn = self.conn_factory()
|
||||
return conn.sign_ssh_challenge(blob=blob, identity=identity)
|
||||
|
||||
|
||||
@handle_connection_error
|
||||
def run_agent(client_factory=client.Client):
|
||||
"""Run ssh-agent using given hardware client factory."""
|
||||
args = create_agent_parser().parse_args()
|
||||
util.setup_logging(verbosity=args.verbose)
|
||||
|
||||
conn = client_factory(device=device.detect())
|
||||
if args.identity.startswith('/'):
|
||||
identities = list(parse_config(fname=args.identity))
|
||||
else:
|
||||
@@ -158,8 +182,6 @@ def run_agent(client_factory=client.Client):
|
||||
identity.identity_dict['proto'] = 'ssh'
|
||||
log.info('identity #%d: %s', index, identity)
|
||||
|
||||
public_keys = [conn.get_public_key(i) for i in identities]
|
||||
|
||||
if args.connect:
|
||||
command = ['ssh'] + ssh_args(args.identity) + args.command
|
||||
elif args.mosh:
|
||||
@@ -171,13 +193,12 @@ def run_agent(client_factory=client.Client):
|
||||
if use_shell:
|
||||
command = os.environ['SHELL']
|
||||
|
||||
if not command:
|
||||
for pk in public_keys:
|
||||
conn = JustInTimeConnection(
|
||||
conn_factory=lambda: client_factory(device.detect()),
|
||||
identities=identities)
|
||||
if command:
|
||||
return run_server(conn=conn, command=command, debug=args.debug,
|
||||
timeout=args.timeout)
|
||||
else:
|
||||
for pk in conn.public_keys():
|
||||
sys.stdout.write(pk)
|
||||
return
|
||||
|
||||
public_keys = [formats.import_public_key(pk) for pk in public_keys]
|
||||
for pk, identity in zip(public_keys, identities):
|
||||
pk['identity'] = identity
|
||||
return run_server(conn=conn, public_keys=public_keys, command=command,
|
||||
debug=args.debug, timeout=args.timeout)
|
||||
|
||||
@@ -71,14 +71,13 @@ def _legacy_pubs(buf):
|
||||
class Handler(object):
|
||||
"""ssh-agent protocol handler."""
|
||||
|
||||
def __init__(self, keys, signer, debug=False):
|
||||
def __init__(self, conn, debug=False):
|
||||
"""
|
||||
Create a protocol handler with specified public keys.
|
||||
|
||||
Use specified signer function to sign SSH authentication requests.
|
||||
"""
|
||||
self.public_keys = keys
|
||||
self.signer = signer
|
||||
self.conn = conn
|
||||
self.debug = debug
|
||||
|
||||
self.methods = {
|
||||
@@ -107,7 +106,7 @@ class Handler(object):
|
||||
def list_pubs(self, buf):
|
||||
"""SSH v2 public keys are serialized and returned."""
|
||||
assert not buf.read()
|
||||
keys = self.public_keys
|
||||
keys = self.conn.parse_public_keys()
|
||||
code = util.pack('B', msg_code('SSH2_AGENT_IDENTITIES_ANSWER'))
|
||||
num = util.pack('L', len(keys))
|
||||
log.debug('available keys: %s', [k['name'] for k in keys])
|
||||
@@ -129,7 +128,7 @@ class Handler(object):
|
||||
assert util.read_frame(buf) == b''
|
||||
assert not buf.read()
|
||||
|
||||
for k in self.public_keys:
|
||||
for k in self.conn.parse_public_keys():
|
||||
if (k['fingerprint']) == (key['fingerprint']):
|
||||
log.debug('using key %r (%s)', k['name'], k['fingerprint'])
|
||||
key = k
|
||||
@@ -140,7 +139,7 @@ class Handler(object):
|
||||
label = key['name'].decode('ascii') # label should be a string
|
||||
log.debug('signing %d-byte blob with "%s" key', len(blob), label)
|
||||
try:
|
||||
signature = self.signer(blob=blob, identity=key['identity'])
|
||||
signature = self.conn.sign(blob=blob, identity=key['identity'])
|
||||
except IOError:
|
||||
return failure()
|
||||
log.debug('signature: %r', signature)
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import mock
|
||||
import pytest
|
||||
|
||||
from .. import device, formats, protocol
|
||||
@@ -15,16 +16,23 @@ NIST256_SIGN_MSG = b'\r\x00\x00\x00h\x00\x00\x00\x13ecdsa-sha2-nistp256\x00\x00\
|
||||
NIST256_SIGN_REPLY = b'\x00\x00\x00j\x0e\x00\x00\x00e\x00\x00\x00\x13ecdsa-sha2-nistp256\x00\x00\x00J\x00\x00\x00!\x00\x88G!\x0c\n\x16:\xbeF\xbe\xb9\xd2\xa9&e\x89\xad\xc4}\x10\xf8\xbc\xdc\xef\x0e\x8d_\x8a6.\xb6\x1f\x00\x00\x00!\x00q\xf0\x16>,\x9a\xde\xe7(\xd6\xd7\x93\x1f\xed\xf9\x94ddw\xfe\xbdq\x13\xbb\xfc\xa9K\xea\x9dC\xa1\xe9' # nopep8
|
||||
|
||||
|
||||
def fake_connection(keys, signer):
|
||||
c = mock.Mock()
|
||||
c.parse_public_keys.return_value = keys
|
||||
c.sign = signer
|
||||
return c
|
||||
|
||||
|
||||
def test_list():
|
||||
key = formats.import_public_key(NIST256_KEY)
|
||||
key['identity'] = device.interface.Identity('ssh://localhost', 'nist256p1')
|
||||
h = protocol.Handler(keys=[key], signer=None)
|
||||
h = protocol.Handler(fake_connection(keys=[key], signer=None))
|
||||
reply = h.handle(LIST_MSG)
|
||||
assert reply == LIST_NIST256_REPLY
|
||||
|
||||
|
||||
def test_unsupported():
|
||||
h = protocol.Handler(keys=[], signer=None)
|
||||
h = protocol.Handler(fake_connection(keys=[], signer=None))
|
||||
reply = h.handle(b'\x09')
|
||||
assert reply == b'\x00\x00\x00\x01\x05'
|
||||
|
||||
@@ -38,13 +46,13 @@ def ecdsa_signer(identity, blob):
|
||||
def test_ecdsa_sign():
|
||||
key = formats.import_public_key(NIST256_KEY)
|
||||
key['identity'] = device.interface.Identity('ssh://localhost', 'nist256p1')
|
||||
h = protocol.Handler(keys=[key], signer=ecdsa_signer)
|
||||
h = protocol.Handler(fake_connection(keys=[key], signer=ecdsa_signer))
|
||||
reply = h.handle(NIST256_SIGN_MSG)
|
||||
assert reply == NIST256_SIGN_REPLY
|
||||
|
||||
|
||||
def test_sign_missing():
|
||||
h = protocol.Handler(keys=[], signer=ecdsa_signer)
|
||||
h = protocol.Handler(fake_connection(keys=[], signer=ecdsa_signer))
|
||||
with pytest.raises(KeyError):
|
||||
h.handle(NIST256_SIGN_MSG)
|
||||
|
||||
@@ -57,7 +65,7 @@ def test_sign_wrong():
|
||||
|
||||
key = formats.import_public_key(NIST256_KEY)
|
||||
key['identity'] = device.interface.Identity('ssh://localhost', 'nist256p1')
|
||||
h = protocol.Handler(keys=[key], signer=wrong_signature)
|
||||
h = protocol.Handler(fake_connection(keys=[key], signer=wrong_signature))
|
||||
with pytest.raises(ValueError):
|
||||
h.handle(NIST256_SIGN_MSG)
|
||||
|
||||
@@ -68,7 +76,7 @@ def test_sign_cancel():
|
||||
|
||||
key = formats.import_public_key(NIST256_KEY)
|
||||
key['identity'] = device.interface.Identity('ssh://localhost', 'nist256p1')
|
||||
h = protocol.Handler(keys=[key], signer=cancel_signature)
|
||||
h = protocol.Handler(fake_connection(keys=[key], signer=cancel_signature))
|
||||
assert h.handle(NIST256_SIGN_MSG) == protocol.failure()
|
||||
|
||||
|
||||
@@ -89,6 +97,6 @@ def ed25519_signer(identity, blob):
|
||||
def test_ed25519_sign():
|
||||
key = formats.import_public_key(ED25519_KEY)
|
||||
key['identity'] = device.interface.Identity('ssh://localhost', 'ed25519')
|
||||
h = protocol.Handler(keys=[key], signer=ed25519_signer)
|
||||
h = protocol.Handler(fake_connection(keys=[key], signer=ed25519_signer))
|
||||
reply = h.handle(ED25519_SIGN_MSG)
|
||||
assert reply == ED25519_SIGN_REPLY
|
||||
|
||||
@@ -37,10 +37,16 @@ class FakeSocket(object):
|
||||
pass
|
||||
|
||||
|
||||
def empty_device():
|
||||
c = mock.Mock(spec=['parse_public_keys'])
|
||||
c.parse_public_keys.return_value = []
|
||||
return c
|
||||
|
||||
|
||||
def test_handle():
|
||||
mutex = threading.Lock()
|
||||
|
||||
handler = protocol.Handler(keys=[], signer=None)
|
||||
handler = protocol.Handler(conn=empty_device())
|
||||
conn = FakeSocket()
|
||||
server.handle_connection(conn, handler, mutex)
|
||||
|
||||
@@ -67,7 +73,6 @@ def test_handle():
|
||||
|
||||
|
||||
def test_server_thread():
|
||||
|
||||
connections = [FakeSocket()]
|
||||
quit_event = threading.Event()
|
||||
|
||||
@@ -81,8 +86,10 @@ def test_server_thread():
|
||||
def getsockname(self): # pylint: disable=no-self-use
|
||||
return 'fake_server'
|
||||
|
||||
handler = protocol.Handler(keys=[], signer=None),
|
||||
handle_conn = functools.partial(server.handle_connection, handler=handler)
|
||||
handler = protocol.Handler(conn=empty_device()),
|
||||
handle_conn = functools.partial(server.handle_connection,
|
||||
handler=handler,
|
||||
mutex=None)
|
||||
server.server_thread(sock=FakeServer(),
|
||||
handle_conn=handle_conn,
|
||||
quit_event=quit_event)
|
||||
@@ -111,7 +118,7 @@ def test_run():
|
||||
|
||||
|
||||
def test_serve_main():
|
||||
handler = protocol.Handler(keys=[], signer=None)
|
||||
handler = protocol.Handler(conn=empty_device())
|
||||
with server.serve(handler=handler, sock_path=None):
|
||||
pass
|
||||
|
||||
|
||||
Reference in New Issue
Block a user