diff --git a/trezor_agent/protocol.py b/trezor_agent/protocol.py index fa90c01..9418224 100644 --- a/trezor_agent/protocol.py +++ b/trezor_agent/protocol.py @@ -15,13 +15,57 @@ from . import formats, util log = logging.getLogger(__name__) -SSH_AGENTC_REQUEST_RSA_IDENTITIES = 1 -SSH_AGENT_RSA_IDENTITIES_ANSWER = 2 -SSH2_AGENTC_REQUEST_IDENTITIES = 11 -SSH2_AGENT_IDENTITIES_ANSWER = 12 -SSH2_AGENTC_SIGN_REQUEST = 13 -SSH2_AGENT_SIGN_RESPONSE = 14 +# Taken from https://github.com/openssh/openssh-portable/blob/master/authfd.h +COMMANDS = dict( + SSH_AGENTC_REQUEST_RSA_IDENTITIES=1, + SSH_AGENT_RSA_IDENTITIES_ANSWER=2, + SSH_AGENTC_RSA_CHALLENGE=3, + SSH_AGENT_RSA_RESPONSE=4, + SSH_AGENT_FAILURE=5, + SSH_AGENT_SUCCESS=6, + SSH_AGENTC_ADD_RSA_IDENTITY=7, + SSH_AGENTC_REMOVE_RSA_IDENTITY=8, + SSH_AGENTC_REMOVE_ALL_RSA_IDENTITIES=9, + SSH2_AGENTC_REQUEST_IDENTITIES=11, + SSH2_AGENT_IDENTITIES_ANSWER=12, + SSH2_AGENTC_SIGN_REQUEST=13, + SSH2_AGENT_SIGN_RESPONSE=14, + SSH2_AGENTC_ADD_IDENTITY=17, + SSH2_AGENTC_REMOVE_IDENTITY=18, + SSH2_AGENTC_REMOVE_ALL_IDENTITIES=19, + SSH_AGENTC_ADD_SMARTCARD_KEY=20, + SSH_AGENTC_REMOVE_SMARTCARD_KEY=21, + SSH_AGENTC_LOCK=22, + SSH_AGENTC_UNLOCK=23, + SSH_AGENTC_ADD_RSA_ID_CONSTRAINED=24, + SSH2_AGENTC_ADD_ID_CONSTRAINED=25, + SSH_AGENTC_ADD_SMARTCARD_KEY_CONSTRAINED=26, +) + + +def msg_code(name): + """Convert string name into a integer message code.""" + return COMMANDS[name] + + +def msg_name(code): + """Convert integer message code into a string name.""" + ids = {v: k for k, v in COMMANDS.items()} + return ids[code] + + +def _fail(): + error_msg = util.pack('B', msg_code('SSH_AGENT_FAILURE')) + return util.frame(error_msg) + + +def _legacy_pubs(buf): + """SSH v1 public keys are not supported.""" + assert not buf.read() + code = util.pack('B', msg_code('SSH_AGENT_RSA_IDENTITIES_ANSWER')) + num = util.pack('L', 0) # no SSH v1 keys + return util.frame(code, num) class Handler(object): @@ -38,9 +82,9 @@ class Handler(object): self.debug = debug self.methods = { - SSH_AGENTC_REQUEST_RSA_IDENTITIES: Handler.legacy_pubs, - SSH2_AGENTC_REQUEST_IDENTITIES: self.list_pubs, - SSH2_AGENTC_SIGN_REQUEST: self.sign_message, + msg_code('SSH_AGENTC_REQUEST_RSA_IDENTITIES'): _legacy_pubs, + msg_code('SSH2_AGENTC_REQUEST_IDENTITIES'): self.list_pubs, + msg_code('SSH2_AGENTC_SIGN_REQUEST'): self.sign_message, } def handle(self, msg): @@ -49,6 +93,10 @@ class Handler(object): log.debug('request: %d bytes%s', len(msg), debug_msg) buf = io.BytesIO(msg) code, = util.recv(buf, '>B') + if code not in self.methods: + log.warning('Unsupported command: %s (%d)', msg_name(code), code) + return _fail() + method = self.methods[code] log.debug('calling %s()', method.__name__) reply = method(buf=buf) @@ -56,19 +104,11 @@ class Handler(object): log.debug('reply: %d bytes%s', len(reply), debug_reply) return reply - @staticmethod - def legacy_pubs(buf): - """SSH v1 public keys are not supported.""" - assert not buf.read() - code = util.pack('B', SSH_AGENT_RSA_IDENTITIES_ANSWER) - num = util.pack('L', 0) # no SSH v1 keys - return util.frame(code, num) - def list_pubs(self, buf): """SSH v2 public keys are serialized and returned.""" assert not buf.read() keys = self.public_keys - code = util.pack('B', SSH2_AGENT_IDENTITIES_ANSWER) + code = util.pack('B', msg_code('SSH2_AGENT_IDENTITIES_ANSWER')) num = util.pack('L', len(keys)) log.debug('available keys: %s', [k['name'] for k in keys]) for i, k in enumerate(keys): @@ -112,5 +152,5 @@ class Handler(object): log.debug('signature size: %d bytes', len(sig_bytes)) data = util.frame(util.frame(key['type']), util.frame(sig_bytes)) - code = util.pack('B', SSH2_AGENT_SIGN_RESPONSE) + code = util.pack('B', msg_code('SSH2_AGENT_SIGN_RESPONSE')) return util.frame(code, data) diff --git a/trezor_agent/tests/test_protocol.py b/trezor_agent/tests/test_protocol.py index e923bc5..f3578cb 100644 --- a/trezor_agent/tests/test_protocol.py +++ b/trezor_agent/tests/test_protocol.py @@ -22,6 +22,12 @@ def test_list(): assert reply == LIST_NIST256_REPLY +def test_unsupported(): + h = protocol.Handler(keys=[], signer=None) + reply = h.handle(b'\x09') + assert reply == b'\x00\x00\x00\x01\x05' + + def ecdsa_signer(label, blob): assert label == 'ssh://localhost' assert blob == NIST256_BLOB diff --git a/trezor_agent/tests/test_server.py b/trezor_agent/tests/test_server.py index 578dc1d..31487e6 100644 --- a/trezor_agent/tests/test_server.py +++ b/trezor_agent/tests/test_server.py @@ -41,16 +41,23 @@ def test_handle(): conn = FakeSocket() server.handle_connection(conn, handler) - msg = bytearray([protocol.SSH_AGENTC_REQUEST_RSA_IDENTITIES]) + msg = bytearray([protocol.msg_code('SSH_AGENTC_REQUEST_RSA_IDENTITIES')]) conn = FakeSocket(util.frame(msg)) server.handle_connection(conn, handler) assert conn.tx.getvalue() == b'\x00\x00\x00\x05\x02\x00\x00\x00\x00' - msg = bytearray([protocol.SSH2_AGENTC_REQUEST_IDENTITIES]) + msg = bytearray([protocol.msg_code('SSH2_AGENTC_REQUEST_IDENTITIES')]) conn = FakeSocket(util.frame(msg)) server.handle_connection(conn, handler) assert conn.tx.getvalue() == b'\x00\x00\x00\x05\x0C\x00\x00\x00\x00' + msg = bytearray([protocol.msg_code('SSH2_AGENTC_ADD_IDENTITY')]) + conn = FakeSocket(util.frame(msg)) + server.handle_connection(conn, handler) + conn.tx.seek(0) + reply = util.read_frame(conn.tx) + assert reply == util.pack('B', protocol.msg_code('SSH_AGENT_FAILURE')) + conn_mock = mock.Mock(spec=FakeSocket) conn_mock.recv.side_effect = [Exception, EOFError] server.handle_connection(conn=conn_mock, handler=None)