code refactoring
This commit is contained in:
80
sshagent/formats.py
Normal file
80
sshagent/formats.py
Normal file
@@ -0,0 +1,80 @@
|
||||
import io
|
||||
import hashlib
|
||||
import base64
|
||||
import ecdsa
|
||||
|
||||
import logging
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
from . import util
|
||||
|
||||
def fingerprint(blob):
|
||||
digest = hashlib.md5(blob).digest()
|
||||
return ':'.join('{:02x}'.format(c) for c in bytearray(digest))
|
||||
|
||||
DER_OCTET_STRING = b'\x04'
|
||||
|
||||
curve = ecdsa.NIST256p
|
||||
hashfunc = hashlib.sha256
|
||||
|
||||
def parse_pubkey(blob):
|
||||
s = io.BytesIO(blob)
|
||||
key_type = util.read_frame(s)
|
||||
log.debug('key type: %s', key_type)
|
||||
curve_name = util.read_frame(s)
|
||||
log.debug('curve name: %s', curve_name)
|
||||
point = util.read_frame(s)
|
||||
_type, point = point[:1], point[1:]
|
||||
assert _type == DER_OCTET_STRING
|
||||
size = len(point) // 2
|
||||
assert len(point) == 2 * size
|
||||
coords = (util.bytes2num(point[:size]), util.bytes2num(point[size:]))
|
||||
log.debug('coordinates: %s', coords)
|
||||
fp = fingerprint(blob)
|
||||
|
||||
point = ecdsa.ellipticcurve.Point(curve.curve, *coords)
|
||||
vk = ecdsa.VerifyingKey.from_public_point(point, curve, hashfunc)
|
||||
result = {
|
||||
'point': coords,
|
||||
'curve': curve_name,
|
||||
'fingerprint': fp,
|
||||
'type': key_type,
|
||||
'blob': blob,
|
||||
'size': size,
|
||||
'verifying_key': vk
|
||||
}
|
||||
return result
|
||||
|
||||
def load_public_key(filename):
|
||||
with open(filename) as f:
|
||||
return parse_public_key(f.read())
|
||||
|
||||
def parse_public_key(data):
|
||||
file_type, base64blob, name = data.split()
|
||||
blob = base64.b64decode(base64blob)
|
||||
result = parse_pubkey(blob)
|
||||
result['name'] = name.encode('ascii')
|
||||
assert result['type'] == file_type.encode('ascii')
|
||||
log.debug('loaded %s %s', file_type, result['fingerprint'])
|
||||
return result
|
||||
|
||||
def decompress_pubkey(pub):
|
||||
P = curve.curve.p()
|
||||
A = curve.curve.a()
|
||||
B = curve.curve.b()
|
||||
x = util.bytes2num(pub[1:33])
|
||||
beta = pow(int(x*x*x+A*x+B), int((P+1)//4), int(P))
|
||||
y = (P-beta) if ((beta + ord(pub[0])) % 2) else beta
|
||||
return (x, y)
|
||||
|
||||
|
||||
def export_public_key(pubkey, label):
|
||||
x, y = decompress_pubkey(pubkey)
|
||||
point = ecdsa.ellipticcurve.Point(curve.curve, x, y)
|
||||
vk = ecdsa.VerifyingKey.from_public_point(point, curve=curve,
|
||||
hashfunc=hashfunc)
|
||||
key_type = 'ecdsa-sha2-nistp256'
|
||||
curve_name = 'nistp256'
|
||||
blobs = map(util.frame, [key_type, curve_name, '\x04' + vk.to_string()])
|
||||
b64 = base64.b64encode(''.join(blobs))
|
||||
return '{} {} {}\n'.format(key_type, b64, label)
|
||||
@@ -1,92 +1,11 @@
|
||||
import io
|
||||
import struct
|
||||
import hashlib
|
||||
import ecdsa
|
||||
import base64
|
||||
|
||||
from . import util
|
||||
from . import formats
|
||||
|
||||
import logging
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
def send(conn, data, fmt=None):
|
||||
if fmt:
|
||||
data = struct.pack(fmt, *data)
|
||||
conn.sendall(data)
|
||||
|
||||
def recv(conn, size):
|
||||
try:
|
||||
fmt = size
|
||||
size = struct.calcsize(fmt)
|
||||
except TypeError:
|
||||
fmt = None
|
||||
try:
|
||||
_read = conn.recv
|
||||
except AttributeError:
|
||||
_read = conn.read
|
||||
|
||||
res = io.BytesIO()
|
||||
while size > 0:
|
||||
buf = _read(size)
|
||||
if not buf:
|
||||
raise EOFError
|
||||
size = size - len(buf)
|
||||
res.write(buf)
|
||||
res = res.getvalue()
|
||||
if fmt:
|
||||
return struct.unpack(fmt, res)
|
||||
else:
|
||||
return res
|
||||
|
||||
|
||||
def read_frame(conn):
|
||||
size, = recv(conn, '>L')
|
||||
return recv(conn, size)
|
||||
|
||||
def bytes2num(s):
|
||||
res = 0
|
||||
for i, c in enumerate(reversed(bytearray(s))):
|
||||
res += c << (i * 8)
|
||||
return res
|
||||
|
||||
|
||||
def parse_pubkey(blob):
|
||||
s = io.BytesIO(blob)
|
||||
key_type = read_frame(s)
|
||||
log.debug('key type: %s', key_type)
|
||||
curve = read_frame(s)
|
||||
log.debug('curve name: %s', curve)
|
||||
point = read_frame(s)
|
||||
_type, point = point[:1], point[1:]
|
||||
assert _type == DER_OCTET_STRING
|
||||
size = len(point) // 2
|
||||
assert len(point) == 2 * size
|
||||
coords = map(bytes2num, [point[:size], point[size:]])
|
||||
log.debug('coordinates: %s', coords)
|
||||
fp = fingerprint(blob)
|
||||
result = {
|
||||
'point': tuple(coords), 'curve': curve,
|
||||
'fingerprint': fp,
|
||||
'type': key_type,
|
||||
'blob': blob, 'size': size
|
||||
}
|
||||
return result
|
||||
|
||||
def list_keys(c):
|
||||
send(c, [0x1, 0xB], '>LB')
|
||||
buf = io.BytesIO(read_frame(c))
|
||||
assert recv(buf, '>B') == (0xC,)
|
||||
num, = recv(buf, '>L')
|
||||
for i in range(num):
|
||||
k = parse_pubkey(read_frame(buf))
|
||||
k['comment'] = read_frame(buf)
|
||||
yield k
|
||||
|
||||
def frame(*msgs):
|
||||
res = io.BytesIO()
|
||||
for msg in msgs:
|
||||
res.write(msg)
|
||||
msg = res.getvalue()
|
||||
return pack('L', len(msg)) + msg
|
||||
|
||||
SSH_AGENTC_REQUEST_RSA_IDENTITIES = 1
|
||||
SSH_AGENT_RSA_IDENTITIES_ANSWER = 2
|
||||
|
||||
@@ -100,39 +19,34 @@ SSH2_AGENTC_ADD_IDENTITY = 17
|
||||
SSH2_AGENTC_REMOVE_IDENTITY = 18
|
||||
SSH2_AGENTC_REMOVE_ALL_IDENTITIES = 19
|
||||
|
||||
def pack(fmt, *args):
|
||||
return struct.pack('>' + fmt, *args)
|
||||
def list_keys(c):
|
||||
util.send(c, [0x1, 0xB], '>LB')
|
||||
buf = io.BytesIO(util.read_frame(c))
|
||||
assert util.recv(buf, '>B') == (0xC,)
|
||||
num, = util.recv(buf, '>L')
|
||||
for i in range(num):
|
||||
k = formats.parse_pubkey(util.read_frame(buf))
|
||||
k['comment'] = util.read_frame(buf)
|
||||
yield k
|
||||
|
||||
def legacy_pubs(buf, keys, signer):
|
||||
code = pack('B', SSH_AGENT_RSA_IDENTITIES_ANSWER)
|
||||
num = pack('L', 0) # no SSH v1 keys
|
||||
return frame(code, num)
|
||||
code = util.pack('B', SSH_AGENT_RSA_IDENTITIES_ANSWER)
|
||||
num = util.pack('L', 0) # no SSH v1 keys
|
||||
return util.frame(code, num)
|
||||
|
||||
def list_pubs(buf, keys, signer):
|
||||
code = pack('B', SSH2_AGENT_IDENTITIES_ANSWER)
|
||||
num = pack('L', len(keys))
|
||||
code = util.pack('B', SSH2_AGENT_IDENTITIES_ANSWER)
|
||||
num = util.pack('L', len(keys))
|
||||
log.debug('available keys: %s', [k['name'] for k in keys])
|
||||
for i, k in enumerate(keys):
|
||||
log.debug('%2d) %s', i+1, k['fingerprint'])
|
||||
pubs = [frame(k['blob']) + frame(k['name']) for k in keys]
|
||||
return frame(code, num, *pubs)
|
||||
|
||||
def fingerprint(blob):
|
||||
digest = hashlib.md5(blob).digest()
|
||||
return ':'.join('{:02x}'.format(c) for c in bytearray(digest))
|
||||
|
||||
def num2bytes(value, size):
|
||||
res = []
|
||||
for i in range(size):
|
||||
res.append(value & 0xFF)
|
||||
value = value >> 8
|
||||
assert value == 0
|
||||
return bytearray(list(reversed(res)))
|
||||
pubs = [util.frame(k['blob']) + util.frame(k['name']) for k in keys]
|
||||
return util.frame(code, num, *pubs)
|
||||
|
||||
def sign_message(buf, keys, signer):
|
||||
key = parse_pubkey(read_frame(buf))
|
||||
key = formats.parse_pubkey(util.read_frame(buf))
|
||||
log.debug('looking for %s', key['fingerprint'])
|
||||
blob = read_frame(buf)
|
||||
blob = util.read_frame(buf)
|
||||
|
||||
for k in keys:
|
||||
if (k['fingerprint']) == (key['fingerprint']):
|
||||
@@ -145,27 +59,23 @@ def sign_message(buf, keys, signer):
|
||||
log.debug('signing %d-byte blob', len(blob))
|
||||
r, s = signer(label=k['name'], blob=blob)
|
||||
signature = (r, s)
|
||||
|
||||
log.debug('signature: %s', signature)
|
||||
|
||||
curve = ecdsa.curves.NIST256p
|
||||
point = ecdsa.ellipticcurve.Point(curve.curve, *key['point'])
|
||||
vk = ecdsa.VerifyingKey.from_public_point(point, curve, hashlib.sha256)
|
||||
success = vk.verify(signature=signature, data=blob,
|
||||
sigdecode=lambda sig, _: sig)
|
||||
success = key['verifying_key'].verify(signature=signature, data=blob,
|
||||
sigdecode=lambda sig, _: sig)
|
||||
log.info('signature status: %s', 'OK' if success else 'ERROR')
|
||||
if not success:
|
||||
raise ValueError('invalid signature')
|
||||
|
||||
sig_bytes = io.BytesIO()
|
||||
for x in signature:
|
||||
sig_bytes.write(frame(b'\x00' + num2bytes(x, key['size'])))
|
||||
sig_bytes.write(util.frame(b'\x00' + util.num2bytes(x, key['size'])))
|
||||
sig_bytes = sig_bytes.getvalue()
|
||||
log.debug('signature size: %d bytes', len(sig_bytes))
|
||||
|
||||
data = frame(frame(key['type']), frame(sig_bytes))
|
||||
code = pack('B', SSH2_AGENT_SIGN_RESPONSE)
|
||||
return frame(code, data)
|
||||
data = util.frame(util.frame(key['type']), util.frame(sig_bytes))
|
||||
code = util.pack('B', SSH2_AGENT_SIGN_RESPONSE)
|
||||
return util.frame(code, data)
|
||||
|
||||
handlers = {
|
||||
SSH_AGENTC_REQUEST_RSA_IDENTITIES: legacy_pubs,
|
||||
@@ -173,36 +83,12 @@ handlers = {
|
||||
SSH2_AGENTC_SIGN_REQUEST: sign_message,
|
||||
}
|
||||
|
||||
def handle_connection(conn, keys, signer):
|
||||
try:
|
||||
log.debug('welcome agent')
|
||||
while True:
|
||||
msg = read_frame(conn)
|
||||
buf = io.BytesIO(msg)
|
||||
code, = recv(buf, '>B')
|
||||
log.debug('request: %d bytes', len(msg))
|
||||
handler = handlers[code]
|
||||
log.debug('calling %s()', handler.__name__)
|
||||
reply = handler(buf=buf, keys=keys, signer=signer)
|
||||
log.debug('reply: %d bytes', len(reply))
|
||||
send(conn, reply)
|
||||
except EOFError:
|
||||
log.debug('goodbye agent')
|
||||
except:
|
||||
log.exception('error')
|
||||
raise
|
||||
|
||||
DER_OCTET_STRING = b'\x04'
|
||||
|
||||
def load_public_key(filename):
|
||||
with open(filename) as f:
|
||||
return parse_public_key(f.read())
|
||||
|
||||
def parse_public_key(data):
|
||||
file_type, base64blob, name = data.split()
|
||||
blob = base64.b64decode(base64blob)
|
||||
result = parse_pubkey(blob)
|
||||
result['name'] = name.encode('ascii')
|
||||
assert result['type'] == file_type.encode('ascii')
|
||||
log.debug('loaded %s %s', file_type, result['fingerprint'])
|
||||
return result
|
||||
def handle_message(msg, keys, signer):
|
||||
log.debug('request: %d bytes', len(msg))
|
||||
buf = io.BytesIO(msg)
|
||||
code, = util.recv(buf, '>B')
|
||||
handler = handlers[code]
|
||||
log.debug('calling %s()', handler.__name__)
|
||||
reply = handler(buf=buf, keys=keys, signer=signer)
|
||||
log.debug('reply: %d bytes', len(reply))
|
||||
return reply
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
#!/usr/bin/env python
|
||||
import socket
|
||||
import os
|
||||
import sys
|
||||
import subprocess
|
||||
import tempfile
|
||||
import contextlib
|
||||
@@ -9,8 +7,9 @@ import threading
|
||||
import logging
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
import protocol
|
||||
|
||||
from . import protocol
|
||||
from . import formats
|
||||
from . import util
|
||||
|
||||
@contextlib.contextmanager
|
||||
def unix_domain_socket_server(sock_path):
|
||||
@@ -29,6 +28,18 @@ def unix_domain_socket_server(sock_path):
|
||||
finally:
|
||||
os.remove(sock_path)
|
||||
|
||||
def handle_connection(conn, keys, signer):
|
||||
try:
|
||||
log.debug('welcome agent')
|
||||
while True:
|
||||
msg = util.read_frame(conn)
|
||||
reply = protocol.handle_message(msg=msg, keys=keys, signer=signer)
|
||||
util.send(conn, reply)
|
||||
except EOFError:
|
||||
log.debug('goodbye agent')
|
||||
except:
|
||||
log.exception('error')
|
||||
raise
|
||||
|
||||
def server_thread(server, keys, signer):
|
||||
log.debug('server thread started')
|
||||
@@ -40,7 +51,7 @@ def server_thread(server, keys, signer):
|
||||
log.debug('server error: %s', e, exc_info=True)
|
||||
break
|
||||
with contextlib.closing(conn):
|
||||
protocol.handle_connection(conn, keys, signer)
|
||||
handle_connection(conn, keys, signer)
|
||||
log.debug('server thread stopped')
|
||||
|
||||
|
||||
@@ -70,7 +81,7 @@ def serve(key_files, command, signer, sock_path=None):
|
||||
if sock_path is None:
|
||||
sock_path = tempfile.mktemp(prefix='ssh-agent-')
|
||||
|
||||
keys = [protocol.parse_public_key(k) for k in key_files]
|
||||
keys = [formats.parse_public_key(k) for k in key_files]
|
||||
environ = {'SSH_AUTH_SOCK': sock_path, 'SSH_AGENT_PID': str(os.getpid())}
|
||||
with unix_domain_socket_server(sock_path) as server:
|
||||
with spawn(server_thread, server=server, keys=keys, signer=signer):
|
||||
@@ -79,6 +90,4 @@ def serve(key_files, command, signer, sock_path=None):
|
||||
finally:
|
||||
log.debug('closing server')
|
||||
server.shutdown(socket.SHUT_RD)
|
||||
|
||||
log.info('exitcode: %d', ret)
|
||||
sys.exit(ret)
|
||||
return ret
|
||||
|
||||
@@ -1,50 +1,20 @@
|
||||
import io
|
||||
import base64
|
||||
import logging
|
||||
import binascii
|
||||
|
||||
from trezorlib.client import TrezorClient
|
||||
from trezorlib.transport_hid import HidTransport
|
||||
from trezorlib.types_pb2 import IdentityType
|
||||
|
||||
import ecdsa
|
||||
import bitcoin
|
||||
import hashlib
|
||||
from . import util
|
||||
from . import formats
|
||||
|
||||
|
||||
import protocol
|
||||
import logging
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
curve = ecdsa.NIST256p
|
||||
hashfunc = hashlib.sha256
|
||||
|
||||
|
||||
def decode_pubkey(pub):
|
||||
P = curve.curve.p()
|
||||
A = curve.curve.a()
|
||||
B = curve.curve.b()
|
||||
x = bitcoin.decode(pub[1:33], 256)
|
||||
beta = pow(int(x*x*x+A*x+B), int((P+1)//4), int(P))
|
||||
y = (P-beta) if ((beta + bitcoin.from_byte_to_int(pub[0])) % 2) else beta
|
||||
return (x, y)
|
||||
|
||||
|
||||
def export_public_key(pubkey, label):
|
||||
x, y = decode_pubkey(pubkey)
|
||||
point = ecdsa.ellipticcurve.Point(curve.curve, x, y)
|
||||
vk = ecdsa.VerifyingKey.from_public_point(point, curve=curve,
|
||||
hashfunc=hashfunc)
|
||||
key_type = 'ecdsa-sha2-nistp256'
|
||||
curve_name = 'nistp256'
|
||||
blobs = map(protocol.frame, [key_type, curve_name, '\x04' + vk.to_string()])
|
||||
b64 = base64.b64encode(''.join(blobs))
|
||||
return '{} {} {}\n'.format(key_type, b64, label)
|
||||
|
||||
|
||||
def label_addr(ident):
|
||||
index = '\x00' * 4
|
||||
addr = index + '{}://{}'.format(ident.proto, ident.host)
|
||||
h = bytearray(hashfunc(addr).digest())
|
||||
h = bytearray(formats.hashfunc(addr).digest())
|
||||
|
||||
address_n = [0] * 5
|
||||
address_n[0] = 13
|
||||
@@ -96,8 +66,8 @@ class Client(object):
|
||||
s = self.client.sign_identity(identity=ident,
|
||||
challenge_hidden=blob,
|
||||
challenge_visual=request)
|
||||
r = protocol.bytes2num(s.signature[:32])
|
||||
s = protocol.bytes2num(s.signature[32:])
|
||||
r = util.bytes2num(s.signature[:32])
|
||||
s = util.bytes2num(s.signature[32:])
|
||||
return (r, s)
|
||||
|
||||
|
||||
@@ -105,14 +75,14 @@ def parse_ssh_blob(data):
|
||||
res = {}
|
||||
if data:
|
||||
i = io.BytesIO(data)
|
||||
res['nonce'] = protocol.read_frame(i)
|
||||
res['nonce'] = util.read_frame(i)
|
||||
i.read(1) # TBD
|
||||
res['user'] = protocol.read_frame(i)
|
||||
res['conn'] = protocol.read_frame(i)
|
||||
res['auth'] = protocol.read_frame(i)
|
||||
res['user'] = util.read_frame(i)
|
||||
res['conn'] = util.read_frame(i)
|
||||
res['auth'] = util.read_frame(i)
|
||||
i.read(1) # TBD
|
||||
res['key_type'] = protocol.read_frame(i)
|
||||
res['pubkey'] = protocol.read_frame(i)
|
||||
res['key_type'] = util.read_frame(i)
|
||||
res['pubkey'] = util.read_frame(i)
|
||||
log.debug('%s: user %r via %r (%r)',
|
||||
res['conn'], res['user'], res['auth'], res['key_type'])
|
||||
return res
|
||||
|
||||
@@ -3,8 +3,9 @@ import argparse
|
||||
import logging
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
import trezor
|
||||
import server
|
||||
from . import trezor
|
||||
from . import server
|
||||
from . import formats
|
||||
|
||||
def main():
|
||||
fmt = '%(asctime)s %(levelname)-12s %(message)-100s [%(filename)s]'
|
||||
@@ -24,7 +25,8 @@ def main():
|
||||
key_files = []
|
||||
for label in args.labels:
|
||||
pubkey = client.get_public_key(label=label)
|
||||
key_files.append(trezor.export_public_key(pubkey=pubkey, label=label))
|
||||
key_file = formats.export_public_key(pubkey=pubkey, label=label)
|
||||
key_files.append(key_file)
|
||||
|
||||
if not args.command:
|
||||
sys.stdout.write(''.join(key_files))
|
||||
@@ -32,12 +34,18 @@ def main():
|
||||
|
||||
signer = client.sign_ssh_challenge
|
||||
|
||||
ret = -1
|
||||
try:
|
||||
server.serve(key_files=key_files, command=args.command, signer=signer)
|
||||
ret = server.serve(
|
||||
key_files=key_files,
|
||||
command=args.command,
|
||||
signer=signer)
|
||||
log.info('exitcode: %d', ret)
|
||||
except KeyboardInterrupt:
|
||||
log.info('server stopped')
|
||||
except Exception as e:
|
||||
log.warning(e, exc_info=True)
|
||||
sys.exit(ret)
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
||||
60
sshagent/util.py
Normal file
60
sshagent/util.py
Normal file
@@ -0,0 +1,60 @@
|
||||
import struct
|
||||
import io
|
||||
|
||||
def send(conn, data, fmt=None):
|
||||
if fmt:
|
||||
data = struct.pack(fmt, *data)
|
||||
conn.sendall(data)
|
||||
|
||||
def recv(conn, size):
|
||||
try:
|
||||
fmt = size
|
||||
size = struct.calcsize(fmt)
|
||||
except TypeError:
|
||||
fmt = None
|
||||
try:
|
||||
_read = conn.recv
|
||||
except AttributeError:
|
||||
_read = conn.read
|
||||
|
||||
res = io.BytesIO()
|
||||
while size > 0:
|
||||
buf = _read(size)
|
||||
if not buf:
|
||||
raise EOFError
|
||||
size = size - len(buf)
|
||||
res.write(buf)
|
||||
res = res.getvalue()
|
||||
if fmt:
|
||||
return struct.unpack(fmt, res)
|
||||
else:
|
||||
return res
|
||||
|
||||
|
||||
def read_frame(conn):
|
||||
size, = recv(conn, '>L')
|
||||
return recv(conn, size)
|
||||
|
||||
def bytes2num(s):
|
||||
res = 0
|
||||
for i, c in enumerate(reversed(bytearray(s))):
|
||||
res += c << (i * 8)
|
||||
return res
|
||||
|
||||
def num2bytes(value, size):
|
||||
res = []
|
||||
for i in range(size):
|
||||
res.append(value & 0xFF)
|
||||
value = value >> 8
|
||||
assert value == 0
|
||||
return bytearray(list(reversed(res)))
|
||||
|
||||
def pack(fmt, *args):
|
||||
return struct.pack('>' + fmt, *args)
|
||||
|
||||
def frame(*msgs):
|
||||
res = io.BytesIO()
|
||||
for msg in msgs:
|
||||
res.write(msg)
|
||||
msg = res.getvalue()
|
||||
return pack('L', len(msg)) + msg
|
||||
Reference in New Issue
Block a user