protocol: use Handler class and fix pylint warnings
This commit is contained in:
@@ -20,70 +20,80 @@ SSH2_AGENTC_REMOVE_IDENTITY = 18
|
|||||||
SSH2_AGENTC_REMOVE_ALL_IDENTITIES = 19
|
SSH2_AGENTC_REMOVE_ALL_IDENTITIES = 19
|
||||||
|
|
||||||
|
|
||||||
def legacy_pubs(buf, keys, signer):
|
class Handler(object):
|
||||||
code = util.pack('B', SSH_AGENT_RSA_IDENTITIES_ANSWER)
|
|
||||||
num = util.pack('L', 0) # no SSH v1 keys
|
|
||||||
return util.frame(code, num)
|
|
||||||
|
|
||||||
|
def __init__(self, keys, signer):
|
||||||
|
self.public_keys = keys
|
||||||
|
self.signer = signer
|
||||||
|
|
||||||
def list_pubs(buf, keys, signer):
|
self.methods = {
|
||||||
code = util.pack('B', SSH2_AGENT_IDENTITIES_ANSWER)
|
SSH_AGENTC_REQUEST_RSA_IDENTITIES: Handler.legacy_pubs,
|
||||||
num = util.pack('L', len(keys))
|
SSH2_AGENTC_REQUEST_IDENTITIES: self.list_pubs,
|
||||||
log.debug('available keys: %s', [k['name'] for k in keys])
|
SSH2_AGENTC_SIGN_REQUEST: self.sign_message,
|
||||||
for i, k in enumerate(keys):
|
}
|
||||||
log.debug('%2d) %s', i+1, k['fingerprint'])
|
|
||||||
pubs = [util.frame(k['blob']) + util.frame(k['name']) for k in keys]
|
|
||||||
return util.frame(code, num, *pubs)
|
|
||||||
|
|
||||||
|
def handle(self, msg):
|
||||||
|
log.debug('request: %d bytes', len(msg))
|
||||||
|
buf = io.BytesIO(msg)
|
||||||
|
code, = util.recv(buf, '>B')
|
||||||
|
method = self.methods[code]
|
||||||
|
log.debug('calling %s()', method.__name__)
|
||||||
|
reply = method(buf=buf)
|
||||||
|
log.debug('reply: %d bytes', len(reply))
|
||||||
|
return reply
|
||||||
|
|
||||||
def sign_message(buf, keys, signer):
|
@staticmethod
|
||||||
key = formats.parse_pubkey(util.read_frame(buf))
|
def legacy_pubs(buf):
|
||||||
log.debug('looking for %s', key['fingerprint'])
|
''' SSH v1 public keys are not supported '''
|
||||||
blob = util.read_frame(buf)
|
assert not buf.read()
|
||||||
|
code = util.pack('B', SSH_AGENT_RSA_IDENTITIES_ANSWER)
|
||||||
|
num = util.pack('L', 0) # no SSH v1 keys
|
||||||
|
return util.frame(code, num)
|
||||||
|
|
||||||
for k in keys:
|
def list_pubs(self, buf):
|
||||||
if (k['fingerprint']) == (key['fingerprint']):
|
''' SSH v2 public keys are serialized and returned. '''
|
||||||
log.debug('using key %r (%s)', k['name'], k['fingerprint'])
|
assert not buf.read()
|
||||||
key = k
|
keys = self.public_keys
|
||||||
break
|
code = util.pack('B', SSH2_AGENT_IDENTITIES_ANSWER)
|
||||||
else:
|
num = util.pack('L', len(keys))
|
||||||
raise ValueError('key not found')
|
log.debug('available keys: %s', [k['name'] for k in keys])
|
||||||
|
for i, k in enumerate(keys):
|
||||||
|
log.debug('%2d) %s', i+1, k['fingerprint'])
|
||||||
|
pubs = [util.frame(k['blob']) + util.frame(k['name']) for k in keys]
|
||||||
|
return util.frame(code, num, *pubs)
|
||||||
|
|
||||||
log.debug('signing %d-byte blob', len(blob))
|
def sign_message(self, buf):
|
||||||
r, s = signer(label=key['name'], blob=blob)
|
''' SSH v2 public key authentication is performed. '''
|
||||||
signature = (r, s)
|
key = formats.parse_pubkey(util.read_frame(buf))
|
||||||
log.debug('signature: %s', signature)
|
log.debug('looking for %s', key['fingerprint'])
|
||||||
|
blob = util.read_frame(buf)
|
||||||
|
|
||||||
success = key['verifying_key'].verify(signature=signature, data=blob,
|
for k in self.public_keys:
|
||||||
sigdecode=lambda sig, _: sig)
|
if (k['fingerprint']) == (key['fingerprint']):
|
||||||
log.info('signature status: %s', 'OK' if success else 'ERROR')
|
log.debug('using key %r (%s)', k['name'], k['fingerprint'])
|
||||||
if not success:
|
key = k
|
||||||
raise ValueError('invalid signature')
|
break
|
||||||
|
else:
|
||||||
|
raise ValueError('key not found')
|
||||||
|
|
||||||
sig_bytes = io.BytesIO()
|
log.debug('signing %d-byte blob', len(blob))
|
||||||
for x in signature:
|
r, s = self.signer(label=key['name'], blob=blob)
|
||||||
sig_bytes.write(util.frame(b'\x00' + util.num2bytes(x, key['size'])))
|
signature = (r, s)
|
||||||
sig_bytes = sig_bytes.getvalue()
|
log.debug('signature: %s', signature)
|
||||||
log.debug('signature size: %d bytes', len(sig_bytes))
|
|
||||||
|
|
||||||
data = util.frame(util.frame(key['type']), util.frame(sig_bytes))
|
success = key['verifying_key'].verify(signature=signature, data=blob,
|
||||||
code = util.pack('B', SSH2_AGENT_SIGN_RESPONSE)
|
sigdecode=lambda sig, _: sig)
|
||||||
return util.frame(code, data)
|
log.info('signature status: %s', 'OK' if success else 'ERROR')
|
||||||
|
if not success:
|
||||||
|
raise ValueError('invalid signature')
|
||||||
|
|
||||||
|
sig_bytes = io.BytesIO()
|
||||||
|
for x in signature:
|
||||||
|
x_frame = util.frame(b'\x00' + util.num2bytes(x, key['size']))
|
||||||
|
sig_bytes.write(x_frame)
|
||||||
|
sig_bytes = sig_bytes.getvalue()
|
||||||
|
log.debug('signature size: %d bytes', len(sig_bytes))
|
||||||
|
|
||||||
handlers = {
|
data = util.frame(util.frame(key['type']), util.frame(sig_bytes))
|
||||||
SSH_AGENTC_REQUEST_RSA_IDENTITIES: legacy_pubs,
|
code = util.pack('B', SSH2_AGENT_SIGN_RESPONSE)
|
||||||
SSH2_AGENTC_REQUEST_IDENTITIES: list_pubs,
|
return util.frame(code, data)
|
||||||
SSH2_AGENTC_SIGN_REQUEST: sign_message,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def handle_message(msg, keys, signer):
|
|
||||||
log.debug('request: %d bytes', len(msg))
|
|
||||||
buf = io.BytesIO(msg)
|
|
||||||
code, = util.recv(buf, '>B')
|
|
||||||
handler = handlers[code]
|
|
||||||
log.debug('calling %s()', handler.__name__)
|
|
||||||
reply = handler(buf=buf, keys=keys, signer=signer)
|
|
||||||
log.debug('reply: %d bytes', len(reply))
|
|
||||||
return reply
|
|
||||||
|
|||||||
@@ -31,12 +31,12 @@ def unix_domain_socket_server(sock_path):
|
|||||||
os.remove(sock_path)
|
os.remove(sock_path)
|
||||||
|
|
||||||
|
|
||||||
def handle_connection(conn, keys, signer):
|
def handle_connection(conn, handler):
|
||||||
try:
|
try:
|
||||||
log.debug('welcome agent')
|
log.debug('welcome agent')
|
||||||
while True:
|
while True:
|
||||||
msg = util.read_frame(conn)
|
msg = util.read_frame(conn)
|
||||||
reply = protocol.handle_message(msg=msg, keys=keys, signer=signer)
|
reply = handler.handle(msg=msg)
|
||||||
util.send(conn, reply)
|
util.send(conn, reply)
|
||||||
except EOFError:
|
except EOFError:
|
||||||
log.debug('goodbye agent')
|
log.debug('goodbye agent')
|
||||||
@@ -47,6 +47,7 @@ def handle_connection(conn, keys, signer):
|
|||||||
|
|
||||||
def server_thread(server, keys, signer):
|
def server_thread(server, keys, signer):
|
||||||
log.debug('server thread started')
|
log.debug('server thread started')
|
||||||
|
handler = protocol.Handler(keys=keys, signer=signer)
|
||||||
while True:
|
while True:
|
||||||
log.debug('waiting for connection on %s', server.getsockname())
|
log.debug('waiting for connection on %s', server.getsockname())
|
||||||
try:
|
try:
|
||||||
@@ -55,7 +56,7 @@ def server_thread(server, keys, signer):
|
|||||||
log.debug('server error: %s', e, exc_info=True)
|
log.debug('server error: %s', e, exc_info=True)
|
||||||
break
|
break
|
||||||
with contextlib.closing(conn):
|
with contextlib.closing(conn):
|
||||||
handle_connection(conn, keys, signer)
|
handle_connection(conn, handler)
|
||||||
log.debug('server thread stopped')
|
log.debug('server thread stopped')
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user