Split the package into a shared library and separate per-device packages

This commit is contained in:
Roman Zeyde
2017-04-27 16:31:17 +03:00
parent eb525e1b62
commit 4af881b3cb
54 changed files with 229 additions and 157 deletions

1
libagent/__init__.py Normal file
View File

@@ -0,0 +1 @@
"""SSH-agent implementation using hardware authentication devices."""

64
libagent/client.py Normal file
View File

@@ -0,0 +1,64 @@
"""
Connection to hardware authentication device.
It is used for getting SSH public keys and ECDSA signing of server requests.
"""
import io
import logging
from . import formats, util
log = logging.getLogger(__name__)
class Client(object):
"""Client wrapper for SSH authentication device."""
def __init__(self, device):
"""Connect to hardware device."""
self.device = device
def export_public_keys(self, identities):
"""Export SSH public keys from the device."""
public_keys = []
with self.device:
for i in identities:
pubkey = self.device.pubkey(identity=i)
vk = formats.decompress_pubkey(pubkey=pubkey,
curve_name=i.curve_name)
public_keys.append(formats.export_public_key(vk=vk,
label=str(i)))
return public_keys
def sign_ssh_challenge(self, blob, identity):
"""Sign given blob using a private key on the device."""
msg = _parse_ssh_blob(blob)
log.debug('%s: user %r via %r (%r)',
msg['conn'], msg['user'], msg['auth'], msg['key_type'])
log.debug('nonce: %r', msg['nonce'])
fp = msg['public_key']['fingerprint']
log.debug('fingerprint: %s', fp)
log.debug('hidden challenge size: %d bytes', len(blob))
log.info('please confirm user "%s" login to "%s" using %s...',
msg['user'].decode('ascii'), identity,
self.device)
with self.device:
return self.device.sign(blob=blob, identity=identity)
def _parse_ssh_blob(data):
res = {}
i = io.BytesIO(data)
res['nonce'] = util.read_frame(i)
i.read(1) # SSH2_MSG_USERAUTH_REQUEST == 50 (from ssh2.h, line 108)
res['user'] = util.read_frame(i)
res['conn'] = util.read_frame(i)
res['auth'] = util.read_frame(i)
i.read(1) # have_sig == 1 (from sshconnect2.c, line 1056)
res['key_type'] = util.read_frame(i)
public_key = util.read_frame(i)
res['public_key'] = formats.parse_pubkey(public_key)
assert not i.read()
return res

View File

@@ -0,0 +1,3 @@
"""Cryptographic hardware device management."""
from . import interface

View File

@@ -0,0 +1,135 @@
"""Device abstraction layer."""
import hashlib
import io
import logging
import re
import struct
from .. import formats, util
log = logging.getLogger(__name__)
_identity_regexp = re.compile(''.join([
'^'
r'(?:(?P<proto>.*)://)?',
r'(?:(?P<user>.*)@)?',
r'(?P<host>.*?)',
r'(?::(?P<port>\w*))?',
r'(?P<path>/.*)?',
'$'
]))
def string_to_identity(identity_str):
"""Parse string into Identity dictionary."""
m = _identity_regexp.match(identity_str)
result = m.groupdict()
log.debug('parsed identity: %s', result)
return {k: v for k, v in result.items() if v}
def identity_to_string(identity_dict):
"""Dump Identity dictionary into its string representation."""
result = []
if identity_dict.get('proto'):
result.append(identity_dict['proto'] + '://')
if identity_dict.get('user'):
result.append(identity_dict['user'] + '@')
result.append(identity_dict['host'])
if identity_dict.get('port'):
result.append(':' + identity_dict['port'])
if identity_dict.get('path'):
result.append(identity_dict['path'])
log.debug('identity parts: %s', result)
return ''.join(result)
class Error(Exception):
"""Device-related error."""
class NotFoundError(Error):
"""Device could not be found."""
class DeviceError(Error):
"""Error during device operation."""
class Identity(object):
"""Represent SLIP-0013 identity, together with a elliptic curve choice."""
def __init__(self, identity_str, curve_name):
"""Configure for specific identity and elliptic curve usage."""
self.identity_dict = string_to_identity(identity_str)
self.curve_name = curve_name
def items(self):
"""Return a copy of identity_dict items."""
return self.identity_dict.items()
def __str__(self):
"""Return identity serialized to string."""
return '<{}|{}>'.format(identity_to_string(self.identity_dict), self.curve_name)
def get_bip32_address(self, ecdh=False):
"""Compute BIP32 derivation address according to SLIP-0013/0017."""
index = struct.pack('<L', self.identity_dict.get('index', 0))
addr = index + identity_to_string(self.identity_dict).encode('ascii')
log.debug('bip32 address string: %r', addr)
digest = hashlib.sha256(addr).digest()
s = io.BytesIO(bytearray(digest))
hardened = 0x80000000
addr_0 = 17 if bool(ecdh) else 13
address_n = [addr_0] + list(util.recv(s, '<LLLL'))
return [(hardened | value) for value in address_n]
def get_curve_name(self, ecdh=False):
"""Return correct curve name for device operations."""
if ecdh:
return formats.get_ecdh_curve_name(self.curve_name)
else:
return self.curve_name
class Device(object):
"""Abstract cryptographic hardware device interface."""
def __init__(self):
"""C-tor."""
self.conn = None
def connect(self):
"""Connect to device, otherwise raise NotFoundError."""
raise NotImplementedError()
def __enter__(self):
"""Allow usage as context manager."""
self.conn = self.connect()
return self
def __exit__(self, *args):
"""Close and mark as disconnected."""
try:
self.conn.close()
except Exception as e: # pylint: disable=broad-except
log.exception('close failed: %s', e)
self.conn = None
def pubkey(self, identity, ecdh=False):
"""Get public key (as bytes)."""
raise NotImplementedError()
def sign(self, identity, blob):
"""Sign given blob and return the signature (as bytes)."""
raise NotImplementedError()
def ecdh(self, identity, pubkey):
"""Get shared session key using Elliptic Curve Diffie-Hellman."""
raise NotImplementedError()
def __str__(self):
"""Human-readable representation."""
return '{}'.format(self.__class__.__name__)

View File

@@ -0,0 +1,37 @@
"""KeepKey-related code (see https://www.keepkey.com/)."""
from . import trezor
from .. import formats
def _verify_support(identity, ecdh):
"""Make sure the device supports given configuration."""
protocol = identity.identity_dict['proto']
if protocol not in {'ssh'}:
raise NotImplementedError(
'Unsupported protocol: {}'.format(protocol))
if ecdh:
raise NotImplementedError('No support for ECDH')
if identity.curve_name not in {formats.CURVE_NIST256}:
raise NotImplementedError(
'Unsupported elliptic curve: {}'.format(identity.curve_name))
class KeepKey(trezor.Trezor):
"""Connection to KeepKey device."""
@property
def _defs(self):
from . import keepkey_defs
return keepkey_defs
required_version = '>=1.0.4'
def pubkey(self, identity, ecdh=False):
"""Return public key."""
_verify_support(identity, ecdh)
return trezor.Trezor.pubkey(self, identity=identity, ecdh=ecdh)
def ecdh(self, identity, pubkey):
"""No support for ECDH in KeepKey firmware."""
_verify_support(identity, ecdh=True)

View File

@@ -0,0 +1,9 @@
"""KeepKey-related definitions."""
# pylint: disable=unused-import,import-error
from keepkeylib.client import CallException as Error
from keepkeylib.client import KeepKeyClient as Client
from keepkeylib.messages_pb2 import PassphraseAck
from keepkeylib.transport_hid import HidTransport as Transport
from keepkeylib.types_pb2 import IdentityType

111
libagent/device/ledger.py Normal file
View File

@@ -0,0 +1,111 @@
"""Ledger-related code (see https://www.ledgerwallet.com/)."""
import binascii
import logging
import struct
from ledgerblue import comm # pylint: disable=import-error
from . import interface
log = logging.getLogger(__name__)
def _expand_path(path):
"""Convert BIP32 path into bytes."""
return b''.join((struct.pack('>I', e) for e in path))
def _convert_public_key(ecdsa_curve_name, result):
"""Convert Ledger reply into PublicKey object."""
if ecdsa_curve_name == 'nist256p1':
if (result[64] & 1) != 0:
result = bytearray([0x03]) + result[1:33]
else:
result = bytearray([0x02]) + result[1:33]
else:
result = result[1:]
keyX = bytearray(result[0:32])
keyY = bytearray(result[32:][::-1])
if (keyX[31] & 1) != 0:
keyY[31] |= 0x80
result = b'\x00' + bytes(keyY)
return bytes(result)
class LedgerNanoS(interface.Device):
"""Connection to Ledger Nano S device."""
def connect(self):
"""Enumerate and connect to the first USB HID interface."""
try:
return comm.getDongle()
except comm.CommException as e:
raise interface.NotFoundError(
'{} not connected: "{}"'.format(self, e))
def pubkey(self, identity, ecdh=False):
"""Get PublicKey object for specified BIP32 address and elliptic curve."""
curve_name = identity.get_curve_name(ecdh)
path = _expand_path(identity.get_bip32_address(ecdh))
if curve_name == 'nist256p1':
p2 = '01'
else:
p2 = '02'
apdu = '800200' + p2
apdu = binascii.unhexlify(apdu)
apdu += bytearray([len(path) + 1, len(path) // 4])
apdu += path
result = bytearray(self.conn.exchange(bytes(apdu)))[1:]
return _convert_public_key(curve_name, result)
def sign(self, identity, blob):
"""Sign given blob and return the signature (as bytes)."""
path = _expand_path(identity.get_bip32_address(ecdh=False))
if identity.identity_dict['proto'] == 'ssh':
ins = '04'
p1 = '00'
else:
ins = '08'
p1 = '00'
if identity.curve_name == 'nist256p1':
p2 = '81' if identity.identity_dict['proto'] == 'ssh' else '01'
else:
p2 = '82' if identity.identity_dict['proto'] == 'ssh' else '02'
apdu = '80' + ins + p1 + p2
apdu = binascii.unhexlify(apdu)
apdu += bytearray([len(blob) + len(path) + 1])
apdu += bytearray([len(path) // 4]) + path
apdu += blob
result = bytearray(self.conn.exchange(bytes(apdu)))
if identity.curve_name == 'nist256p1':
offset = 3
length = result[offset]
r = result[offset+1:offset+1+length]
if r[0] == 0:
r = r[1:]
offset = offset + 1 + length + 1
length = result[offset]
s = result[offset+1:offset+1+length]
if s[0] == 0:
s = s[1:]
offset = offset + 1 + length
return bytes(r) + bytes(s)
else:
return bytes(result[:64])
def ecdh(self, identity, pubkey):
"""Get shared session key using Elliptic Curve Diffie-Hellman."""
path = _expand_path(identity.get_bip32_address(ecdh=True))
if identity.curve_name == 'nist256p1':
p2 = '01'
else:
p2 = '02'
apdu = '800a00' + p2
apdu = binascii.unhexlify(apdu)
apdu += bytearray([len(pubkey) + len(path) + 1])
apdu += bytearray([len(path) // 4]) + path
apdu += pubkey
result = bytearray(self.conn.exchange(bytes(apdu)))
assert result[0] == 0x04
return bytes(result)

117
libagent/device/trezor.py Normal file
View File

@@ -0,0 +1,117 @@
"""TREZOR-related code (see http://bitcointrezor.com/)."""
import binascii
import logging
import os
import semver
from . import interface
log = logging.getLogger(__name__)
class Trezor(interface.Device):
"""Connection to TREZOR device."""
@property
def _defs(self):
from . import trezor_defs
# Allow using TREZOR bridge transport (instead of the HID default)
trezor_defs.Transport = {
'bridge': trezor_defs.BridgeTransport,
}.get(os.environ.get('TREZOR_TRANSPORT'), trezor_defs.HidTransport)
return trezor_defs
required_version = '>=1.4.0'
passphrase = os.environ.get('TREZOR_PASSPHRASE', '')
def connect(self):
"""Enumerate and connect to the first USB HID interface."""
def passphrase_handler(_):
log.debug('using %s passphrase for %s',
'non-empty' if self.passphrase else 'empty', self)
return self._defs.PassphraseAck(passphrase=self.passphrase)
for d in self._defs.Transport.enumerate():
log.debug('endpoint: %s', d)
transport = self._defs.Transport(d)
connection = self._defs.Client(transport)
connection.callback_PassphraseRequest = passphrase_handler
f = connection.features
log.debug('connected to %s %s', self, f.device_id)
log.debug('label : %s', f.label)
log.debug('vendor : %s', f.vendor)
current_version = '{}.{}.{}'.format(f.major_version,
f.minor_version,
f.patch_version)
log.debug('version : %s', current_version)
log.debug('revision : %s', binascii.hexlify(f.revision))
if not semver.match(current_version, self.required_version):
fmt = ('Please upgrade your {} firmware to {} version'
' (current: {})')
raise ValueError(fmt.format(self, self.required_version,
current_version))
connection.ping(msg='', pin_protection=True) # unlock PIN
return connection
raise interface.NotFoundError('{} not connected'.format(self))
def close(self):
"""Close connection."""
self.conn.close()
def pubkey(self, identity, ecdh=False):
"""Return public key."""
curve_name = identity.get_curve_name(ecdh=ecdh)
log.debug('"%s" getting public key (%s) from %s',
identity, curve_name, self)
addr = identity.get_bip32_address(ecdh=ecdh)
result = self.conn.get_public_node(n=addr,
ecdsa_curve_name=curve_name)
log.debug('result: %s', result)
return result.node.public_key
def _identity_proto(self, identity):
result = self._defs.IdentityType()
for name, value in identity.items():
setattr(result, name, value)
return result
def sign(self, identity, blob):
"""Sign given blob and return the signature (as bytes)."""
curve_name = identity.get_curve_name(ecdh=False)
log.debug('"%s" signing %r (%s) on %s',
identity, blob, curve_name, self)
try:
result = self.conn.sign_identity(
identity=self._identity_proto(identity),
challenge_hidden=blob,
challenge_visual='',
ecdsa_curve_name=curve_name)
log.debug('result: %s', result)
assert len(result.signature) == 65
assert result.signature[:1] == b'\x00'
return result.signature[1:]
except self._defs.Error as e:
msg = '{} error: {}'.format(self, e)
log.debug(msg, exc_info=True)
raise interface.DeviceError(msg)
def ecdh(self, identity, pubkey):
"""Get shared session key using Elliptic Curve Diffie-Hellman."""
curve_name = identity.get_curve_name(ecdh=True)
log.debug('"%s" shared session key (%s) for %r from %s',
identity, curve_name, pubkey, self)
try:
result = self.conn.get_ecdh_session_key(
identity=self._identity_proto(identity),
peer_public_key=pubkey,
ecdsa_curve_name=curve_name)
log.debug('result: %s', result)
assert len(result.session_key) in {65, 33} # NIST256 or Curve25519
assert result.session_key[:1] == b'\x04'
return result.session_key
except self._defs.Error as e:
msg = '{} error: {}'.format(self, e)
log.debug(msg, exc_info=True)
raise interface.DeviceError(msg)

View File

@@ -0,0 +1,10 @@
"""TREZOR-related definitions."""
# pylint: disable=unused-import,import-error
from trezorlib.client import CallException as Error
from trezorlib.client import TrezorClient as Client
from trezorlib.messages_pb2 import PassphraseAck
from trezorlib.transport_bridge import BridgeTransport
from trezorlib.transport_hid import HidTransport
from trezorlib.types_pb2 import IdentityType

208
libagent/formats.py Normal file
View File

@@ -0,0 +1,208 @@
"""SSH format parsing and formatting tools."""
import base64
import hashlib
import io
import logging
import ecdsa
import ed25519
from . import util
log = logging.getLogger(__name__)
# Supported ECDSA curves (for SSH and GPG)
CURVE_NIST256 = 'nist256p1'
CURVE_ED25519 = 'ed25519'
SUPPORTED_CURVES = {CURVE_NIST256, CURVE_ED25519}
# Supported ECDH curves (for GPG)
ECDH_NIST256 = 'nist256p1'
ECDH_CURVE25519 = 'curve25519'
# SSH key types
SSH_NIST256_DER_OCTET = b'\x04'
SSH_NIST256_KEY_PREFIX = b'ecdsa-sha2-'
SSH_NIST256_CURVE_NAME = b'nistp256'
SSH_NIST256_KEY_TYPE = SSH_NIST256_KEY_PREFIX + SSH_NIST256_CURVE_NAME
SSH_ED25519_KEY_TYPE = b'ssh-ed25519'
SUPPORTED_KEY_TYPES = {SSH_NIST256_KEY_TYPE, SSH_ED25519_KEY_TYPE}
hashfunc = hashlib.sha256
def fingerprint(blob):
"""
Compute SSH fingerprint for specified blob.
See https://en.wikipedia.org/wiki/Public_key_fingerprint for details.
"""
digest = hashlib.md5(blob).digest()
return ':'.join('{:02x}'.format(c) for c in bytearray(digest))
def parse_pubkey(blob):
"""
Parse SSH public key from given blob.
Construct a verifier for ECDSA signatures.
The verifier returns the signatures in the required SSH format.
Currently, NIST256P1 and ED25519 elliptic curves are supported.
"""
fp = fingerprint(blob)
s = io.BytesIO(blob)
key_type = util.read_frame(s)
log.debug('key type: %s', key_type)
assert key_type in SUPPORTED_KEY_TYPES, key_type
result = {'blob': blob, 'type': key_type, 'fingerprint': fp}
if key_type == SSH_NIST256_KEY_TYPE:
curve_name = util.read_frame(s)
log.debug('curve name: %s', curve_name)
point = util.read_frame(s)
assert s.read() == b''
_type, point = point[:1], point[1:]
assert _type == SSH_NIST256_DER_OCTET
size = len(point) // 2
assert len(point) == 2 * size
coords = (util.bytes2num(point[:size]), util.bytes2num(point[size:]))
curve = ecdsa.NIST256p
point = ecdsa.ellipticcurve.Point(curve.curve, *coords)
def ecdsa_verifier(sig, msg):
assert len(sig) == 2 * size
sig_decode = ecdsa.util.sigdecode_string
vk = ecdsa.VerifyingKey.from_public_point(point, curve, hashfunc)
vk.verify(signature=sig, data=msg, sigdecode=sig_decode)
parts = [sig[:size], sig[size:]]
return b''.join([util.frame(b'\x00' + p) for p in parts])
result.update(point=coords, curve=CURVE_NIST256,
verifier=ecdsa_verifier)
if key_type == SSH_ED25519_KEY_TYPE:
pubkey = util.read_frame(s)
assert s.read() == b''
def ed25519_verify(sig, msg):
assert len(sig) == 64
vk = ed25519.VerifyingKey(pubkey)
vk.verify(sig, msg)
return sig
result.update(curve=CURVE_ED25519, verifier=ed25519_verify)
return result
def _decompress_ed25519(pubkey):
"""Load public key from the serialized blob (stripping the prefix byte)."""
if pubkey[:1] == b'\x00':
# set by Trezor fsm_msgSignIdentity() and fsm_msgGetPublicKey()
return ed25519.VerifyingKey(pubkey[1:])
def _decompress_nist256(pubkey):
"""
Load public key from the serialized blob.
The leading byte least-significant bit is used to decide how to recreate
the y-coordinate from the specified x-coordinate. See bitcoin/main.py#L198
(from https://github.com/vbuterin/pybitcointools/) for details.
"""
if pubkey[:1] in {b'\x02', b'\x03'}: # set by ecdsa_get_public_key33()
curve = ecdsa.NIST256p
P = curve.curve.p()
A = curve.curve.a()
B = curve.curve.b()
x = util.bytes2num(pubkey[1:33])
beta = pow(int(x * x * x + A * x + B), int((P + 1) // 4), int(P))
p0 = util.bytes2num(pubkey[:1])
y = (P - beta) if ((beta + p0) % 2) else beta
point = ecdsa.ellipticcurve.Point(curve.curve, x, y)
return ecdsa.VerifyingKey.from_public_point(point, curve=curve,
hashfunc=hashfunc)
def decompress_pubkey(pubkey, curve_name):
"""
Load public key from the serialized blob.
Raise ValueError on parsing error.
"""
vk = None
if len(pubkey) == 33:
decompress = {
CURVE_NIST256: _decompress_nist256,
CURVE_ED25519: _decompress_ed25519,
ECDH_CURVE25519: _decompress_ed25519,
}[curve_name]
vk = decompress(pubkey)
if not vk:
msg = 'invalid {!s} public key: {!r}'.format(curve_name, pubkey)
raise ValueError(msg)
return vk
def serialize_verifying_key(vk):
"""
Serialize a public key into SSH format (for exporting to text format).
Currently, NIST256P1 and ED25519 elliptic curves are supported.
Raise TypeError on unsupported key format.
"""
if isinstance(vk, ed25519.keys.VerifyingKey):
pubkey = vk.to_bytes()
key_type = SSH_ED25519_KEY_TYPE
blob = util.frame(SSH_ED25519_KEY_TYPE) + util.frame(pubkey)
return key_type, blob
if isinstance(vk, ecdsa.keys.VerifyingKey):
curve_name = SSH_NIST256_CURVE_NAME
key_blob = SSH_NIST256_DER_OCTET + vk.to_string()
parts = [SSH_NIST256_KEY_TYPE, curve_name, key_blob]
key_type = SSH_NIST256_KEY_TYPE
blob = b''.join([util.frame(p) for p in parts])
return key_type, blob
raise TypeError('unsupported {!r}'.format(vk))
def export_public_key(vk, label):
"""
Export public key to text format.
The resulting string can be written into a .pub file or
appended to the ~/.ssh/authorized_keys file.
"""
key_type, blob = serialize_verifying_key(vk)
log.debug('fingerprint: %s', fingerprint(blob))
b64 = base64.b64encode(blob).decode('ascii')
return '{} {} {}\n'.format(key_type.decode('ascii'), b64, label)
def import_public_key(line):
"""Parse public key textual format, as saved at a .pub file."""
log.debug('loading SSH public key: %r', line)
file_type, base64blob, name = line.split()
blob = base64.b64decode(base64blob)
result = parse_pubkey(blob)
result['name'] = name.encode('ascii')
assert result['type'] == file_type.encode('ascii')
log.debug('loaded %s public key: %s', file_type, result['fingerprint'])
return result
def get_ecdh_curve_name(signature_curve_name):
"""Return appropriate curve for ECDH for specified signing curve."""
return {
CURVE_NIST256: ECDH_NIST256,
CURVE_ED25519: ECDH_CURVE25519,
ECDH_CURVE25519: ECDH_CURVE25519,
}[signature_curve_name]

145
libagent/gpg/__init__.py Normal file
View File

@@ -0,0 +1,145 @@
"""
TREZOR support for ECDSA GPG signatures.
See these links for more details:
- https://www.gnupg.org/faq/whats-new-in-2.1.html
- https://tools.ietf.org/html/rfc4880
- https://tools.ietf.org/html/rfc6637
- https://tools.ietf.org/html/draft-irtf-cfrg-eddsa-05
"""
import argparse
import contextlib
import logging
import os
import sys
import time
import semver
from . import agent, client, encode, keyring, protocol
from .. import device, formats, server, util
log = logging.getLogger(__name__)
def export_public_key(device_type, args):
"""Generate a new pubkey for a new/existing GPG identity."""
log.warning('NOTE: in order to re-generate the exact same GPG key later, '
'run this command with "--time=%d" commandline flag (to set '
'the timestamp of the GPG key manually).', args.time)
c = client.Client(user_id=args.user_id, curve_name=args.ecdsa_curve,
device_type=device_type)
verifying_key = c.pubkey(ecdh=False)
decryption_key = c.pubkey(ecdh=True)
if args.subkey: # add as subkey
log.info('adding %s GPG subkey for "%s" to existing key',
args.ecdsa_curve, args.user_id)
# subkey for signing
signing_key = protocol.PublicKey(
curve_name=args.ecdsa_curve, created=args.time,
verifying_key=verifying_key, ecdh=False)
# subkey for encryption
encryption_key = protocol.PublicKey(
curve_name=formats.get_ecdh_curve_name(args.ecdsa_curve),
created=args.time, verifying_key=decryption_key, ecdh=True)
primary_bytes = keyring.export_public_key(args.user_id)
result = encode.create_subkey(primary_bytes=primary_bytes,
subkey=signing_key,
signer_func=c.sign)
result = encode.create_subkey(primary_bytes=result,
subkey=encryption_key,
signer_func=c.sign)
else: # add as primary
log.info('creating new %s GPG primary key for "%s"',
args.ecdsa_curve, args.user_id)
# primary key for signing
primary = protocol.PublicKey(
curve_name=args.ecdsa_curve, created=args.time,
verifying_key=verifying_key, ecdh=False)
# subkey for encryption
subkey = protocol.PublicKey(
curve_name=formats.get_ecdh_curve_name(args.ecdsa_curve),
created=args.time, verifying_key=decryption_key, ecdh=True)
result = encode.create_primary(user_id=args.user_id,
pubkey=primary,
signer_func=c.sign)
result = encode.create_subkey(primary_bytes=result,
subkey=subkey,
signer_func=c.sign)
sys.stdout.write(protocol.armor(result, 'PUBLIC KEY BLOCK'))
def run_create(device_type, args):
"""Export public GPG key."""
util.setup_logging(verbosity=args.verbose)
log.warning('This GPG tool is still in EXPERIMENTAL mode, '
'so please note that the API and features may '
'change without backwards compatibility!')
existing_gpg = keyring.gpg_version().decode('ascii')
required_gpg = '>=2.1.11'
if semver.match(existing_gpg, required_gpg):
export_public_key(device_type, args)
else:
log.error('Existing gpg2 has version "%s" (%s required)',
existing_gpg, required_gpg)
def run_unlock(device_type, args):
"""Unlock hardware device (for future interaction)."""
util.setup_logging(verbosity=args.verbose)
with device_type() as d:
log.info('unlocked %s device', d)
def run_agent(device_type):
"""Run a simple GPG-agent server."""
parser = argparse.ArgumentParser()
parser.add_argument('--homedir', default=os.environ.get('GNUPGHOME'))
args, _ = parser.parse_known_args()
assert args.homedir
config_file = os.path.join(args.homedir, 'gpg-agent.conf')
lines = (line.strip() for line in open(config_file))
lines = (line for line in lines if line and not line.startswith('#'))
config = dict(line.split(' ', 1) for line in lines)
util.setup_logging(verbosity=int(config['verbosity']),
filename=config['log-file'])
sock_path = keyring.get_agent_sock_path()
with server.unix_domain_socket_server(sock_path) as sock:
for conn in agent.yield_connections(sock):
with contextlib.closing(conn):
try:
agent.handle_connection(conn=conn, device_type=device_type)
except StopIteration:
log.info('stopping gpg-agent')
return
except Exception as e: # pylint: disable=broad-except
log.exception('gpg-agent failed: %s', e)
def main(device_type):
"""Parse command-line arguments."""
parser = argparse.ArgumentParser()
subparsers = parser.add_subparsers()
p = subparsers.add_parser('create', help='Export public GPG key')
p.add_argument('user_id')
p.add_argument('-e', '--ecdsa-curve', default='nist256p1')
p.add_argument('-t', '--time', type=int, default=int(time.time()))
p.add_argument('-v', '--verbose', default=0, action='count')
p.add_argument('-s', '--subkey', default=False, action='store_true')
p.set_defaults(func=run_create)
p = subparsers.add_parser('unlock', help='Unlock the hardware device')
p.add_argument('-v', '--verbose', default=0, action='count')
p.set_defaults(func=run_unlock)
args = parser.parse_args()
return args.func(device_type=device_type, args=args)

171
libagent/gpg/agent.py Normal file
View File

@@ -0,0 +1,171 @@
"""GPG-agent utilities."""
import binascii
import logging
from . import client, decode, keyring, protocol
from .. import util
log = logging.getLogger(__name__)
def yield_connections(sock):
"""Run a server on the specified socket."""
while True:
log.debug('waiting for connection on %s', sock.getsockname())
try:
conn, _ = sock.accept()
except KeyboardInterrupt:
return
conn.settimeout(None)
log.debug('accepted connection on %s', sock.getsockname())
yield conn
def serialize(data):
"""Serialize data according to ASSUAN protocol."""
for c in [b'%', b'\n', b'\r']:
escaped = '%{:02X}'.format(ord(c)).encode('ascii')
data = data.replace(c, escaped)
return data
def sig_encode(r, s):
"""Serialize ECDSA signature data into GPG S-expression."""
r = serialize(util.num2bytes(r, 32))
s = serialize(util.num2bytes(s, 32))
return b'(7:sig-val(5:ecdsa(1:r32:' + r + b')(1:s32:' + s + b')))'
def open_connection(keygrip_bytes, device_type):
"""
Connect to the device for the specified keygrip.
Parse GPG public key to find the first user ID, which is used to
specify the correct signature/decryption key on the device.
"""
pubkey_dict, user_ids = decode.load_by_keygrip(
pubkey_bytes=keyring.export_public_keys(),
keygrip=keygrip_bytes)
# We assume the first user ID is used to generate TREZOR-based GPG keys.
user_id = user_ids[0]['value'].decode('ascii')
curve_name = protocol.get_curve_name_by_oid(pubkey_dict['curve_oid'])
ecdh = (pubkey_dict['algo'] == protocol.ECDH_ALGO_ID)
conn = client.Client(user_id, curve_name=curve_name, device_type=device_type)
pubkey = protocol.PublicKey(
curve_name=curve_name, created=pubkey_dict['created'],
verifying_key=conn.pubkey(ecdh=ecdh), ecdh=ecdh)
assert pubkey.key_id() == pubkey_dict['key_id']
assert pubkey.keygrip() == keygrip_bytes
return conn
def pksign(keygrip, digest, algo, device_type):
"""Sign a message digest using a private EC key."""
log.debug('signing %r digest (algo #%s)', digest, algo)
keygrip_bytes = binascii.unhexlify(keygrip)
conn = open_connection(keygrip_bytes, device_type=device_type)
r, s = conn.sign(binascii.unhexlify(digest))
result = sig_encode(r, s)
log.debug('result: %r', result)
return result
def _serialize_point(data):
prefix = '{}:'.format(len(data)).encode('ascii')
# https://www.gnupg.org/documentation/manuals/assuan/Server-responses.html
return b'(5:value' + serialize(prefix + data) + b')'
def parse_ecdh(line):
"""Parse ECDH request and return remote public key."""
prefix, line = line.split(b' ', 1)
assert prefix == b'D'
exp, leftover = keyring.parse(keyring.unescape(line))
log.debug('ECDH s-exp: %r', exp)
assert not leftover
label, exp = exp
assert label == b'enc-val'
assert exp[0] == b'ecdh'
items = exp[1:]
log.debug('ECDH parameters: %r', items)
return dict(items)[b'e']
def pkdecrypt(keygrip, conn, device_type):
"""Handle decryption using ECDH."""
for msg in [b'S INQUIRE_MAXLEN 4096', b'INQUIRE CIPHERTEXT']:
keyring.sendline(conn, msg)
line = keyring.recvline(conn)
assert keyring.recvline(conn) == b'END'
remote_pubkey = parse_ecdh(line)
keygrip_bytes = binascii.unhexlify(keygrip)
conn = open_connection(keygrip_bytes, device_type=device_type)
return _serialize_point(conn.ecdh(remote_pubkey))
@util.memoize
def have_key(keygrip, device_type):
"""Check if current keygrip correspond to a TREZOR-based key."""
try:
open_connection(keygrip_bytes=binascii.unhexlify(keygrip),
device_type=device_type)
return True
except KeyError as e:
log.warning('HAVEKEY(%s) failed: %s', keygrip, e)
return False
# pylint: disable=too-many-branches
def handle_connection(conn, device_type):
"""Handle connection from GPG binary using the ASSUAN protocol."""
keygrip = None
digest = None
algo = None
version = keyring.gpg_version() # "Clone" existing GPG version
keyring.sendline(conn, b'OK')
for line in keyring.iterlines(conn):
parts = line.split(b' ')
command = parts[0]
args = parts[1:]
if command in {b'RESET', b'OPTION', b'SETKEYDESC'}:
pass # reply with OK
elif command == b'GETINFO':
keyring.sendline(conn, b'D ' + version)
elif command == b'AGENT_ID':
keyring.sendline(conn, b'D TREZOR') # "Fake" agent ID
elif command in {b'SIGKEY', b'SETKEY'}:
keygrip, = args
elif command == b'SETHASH':
algo, digest = args
elif command == b'PKSIGN':
sig = pksign(keygrip, digest, algo, device_type=device_type)
keyring.sendline(conn, b'D ' + sig)
elif command == b'PKDECRYPT':
sec = pkdecrypt(keygrip, conn, device_type=device_type)
keyring.sendline(conn, b'D ' + sec)
elif command == b'HAVEKEY':
if not have_key(keygrip=args[0], device_type=device_type):
keyring.sendline(conn,
b'ERR 67108881 No secret key <GPG Agent>')
return
elif command == b'KEYINFO':
keygrip, = args
# Dummy reply (mainly for 'gpg --edit' to succeed).
# For details, see GnuPG agent KEYINFO command help.
# https://git.gnupg.org/cgi-bin/gitweb.cgi?p=gnupg.git;a=blob;f=agent/command.c;h=c8b34e9882076b1b724346787781f657cac75499;hb=refs/heads/master#l1082
fmt = 'S KEYINFO {0} X - - - - - - -'
keyring.sendline(conn, fmt.format(keygrip).encode('ascii'))
elif command == b'BYE':
return
elif command == b'KILLAGENT':
keyring.sendline(conn, b'OK')
raise StopIteration
else:
log.error('unknown request: %r', line)
return
keyring.sendline(conn, b'OK')

44
libagent/gpg/client.py Normal file
View File

@@ -0,0 +1,44 @@
"""Device abstraction layer for GPG operations."""
import logging
from .. import device, formats, util
log = logging.getLogger(__name__)
class Client(object):
"""Sign messages and get public keys from a hardware device."""
def __init__(self, user_id, curve_name, device_type):
"""Connect to the device and retrieve required public key."""
self.device = device_type()
self.user_id = user_id
self.identity = device.interface.Identity(
identity_str='gpg://', curve_name=curve_name)
self.identity.identity_dict['host'] = user_id
def pubkey(self, ecdh=False):
"""Return public key as VerifyingKey object."""
with self.device:
pubkey = self.device.pubkey(ecdh=ecdh, identity=self.identity)
return formats.decompress_pubkey(
pubkey=pubkey, curve_name=self.identity.curve_name)
def sign(self, digest):
"""Sign the digest and return a serialized signature."""
log.info('please confirm GPG signature on %s for "%s"...',
self.device, self.user_id)
if self.identity.curve_name == formats.CURVE_NIST256:
digest = digest[:32] # sign the first 256 bits
log.debug('signing digest: %s', util.hexlify(digest))
with self.device:
sig = self.device.sign(blob=digest, identity=self.identity)
return (util.bytes2num(sig[:32]), util.bytes2num(sig[32:]))
def ecdh(self, pubkey):
"""Derive shared secret using ECDH from remote public key."""
log.info('please confirm GPG decryption on %s for "%s"...',
self.device, self.user_id)
with self.device:
return self.device.ecdh(pubkey=pubkey, identity=self.identity)

318
libagent/gpg/decode.py Normal file
View File

@@ -0,0 +1,318 @@
"""Decoders for GPG v2 data structures."""
import base64
import functools
import hashlib
import io
import logging
import struct
import ecdsa
import ed25519
from . import protocol
from .. import util
log = logging.getLogger(__name__)
def parse_subpackets(s):
"""See https://tools.ietf.org/html/rfc4880#section-5.2.3.1 for details."""
subpackets = []
total_size = s.readfmt('>H')
data = s.read(total_size)
s = util.Reader(io.BytesIO(data))
while True:
try:
first = s.readfmt('B')
except EOFError:
break
if first < 192:
subpacket_len = first
elif first < 255:
subpacket_len = ((first - 192) << 8) + s.readfmt('B') + 192
else: # first == 255
subpacket_len = s.readfmt('>L')
subpackets.append(s.read(subpacket_len))
return subpackets
def parse_mpi(s):
"""See https://tools.ietf.org/html/rfc4880#section-3.2 for details."""
bits = s.readfmt('>H')
blob = bytearray(s.read(int((bits + 7) // 8)))
return sum(v << (8 * i) for i, v in enumerate(reversed(blob)))
def parse_mpis(s, n):
"""Parse multiple MPIs from stream."""
return [parse_mpi(s) for _ in range(n)]
def _parse_nist256p1_pubkey(mpi):
prefix, x, y = util.split_bits(mpi, 4, 256, 256)
if prefix != 4:
raise ValueError('Invalid MPI prefix: {}'.format(prefix))
point = ecdsa.ellipticcurve.Point(curve=ecdsa.NIST256p.curve,
x=x, y=y)
return ecdsa.VerifyingKey.from_public_point(
point=point, curve=ecdsa.curves.NIST256p,
hashfunc=hashlib.sha256)
def _parse_ed25519_pubkey(mpi):
prefix, value = util.split_bits(mpi, 8, 256)
if prefix != 0x40:
raise ValueError('Invalid MPI prefix: {}'.format(prefix))
return ed25519.VerifyingKey(util.num2bytes(value, size=32))
SUPPORTED_CURVES = {
b'\x2A\x86\x48\xCE\x3D\x03\x01\x07':
(_parse_nist256p1_pubkey, protocol.keygrip_nist256),
b'\x2B\x06\x01\x04\x01\xDA\x47\x0F\x01':
(_parse_ed25519_pubkey, protocol.keygrip_ed25519),
b'\x2B\x06\x01\x04\x01\x97\x55\x01\x05\x01':
(_parse_ed25519_pubkey, protocol.keygrip_curve25519),
}
RSA_ALGO_IDS = {1, 2, 3}
ELGAMAL_ALGO_ID = 16
DSA_ALGO_ID = 17
ECDSA_ALGO_IDS = {18, 19, 22} # {ecdsa, nist256, ed25519}
def _parse_embedded_signatures(subpackets):
for packet in subpackets:
data = bytearray(packet)
if data[0] == 32:
# https://tools.ietf.org/html/rfc4880#section-5.2.3.26
stream = io.BytesIO(data[1:])
yield _parse_signature(util.Reader(stream))
def has_custom_subpacket(signature_packet):
"""Detect our custom public keys by matching subpacket data."""
return any(protocol.CUSTOM_KEY_LABEL == subpacket[1:]
for subpacket in signature_packet['unhashed_subpackets'])
def _parse_signature(stream):
"""See https://tools.ietf.org/html/rfc4880#section-5.2 for details."""
p = {'type': 'signature'}
to_hash = io.BytesIO()
with stream.capture(to_hash):
p['version'] = stream.readfmt('B')
p['sig_type'] = stream.readfmt('B')
p['pubkey_alg'] = stream.readfmt('B')
p['hash_alg'] = stream.readfmt('B')
p['hashed_subpackets'] = parse_subpackets(stream)
# https://tools.ietf.org/html/rfc4880#section-5.2.4
tail_to_hash = b'\x04\xff' + struct.pack('>L', to_hash.tell())
p['_to_hash'] = to_hash.getvalue() + tail_to_hash
p['unhashed_subpackets'] = parse_subpackets(stream)
embedded = list(_parse_embedded_signatures(p['unhashed_subpackets']))
if embedded:
log.debug('embedded sigs: %s', embedded)
p['embedded'] = embedded
p['hash_prefix'] = stream.readfmt('2s')
if p['pubkey_alg'] in ECDSA_ALGO_IDS:
p['sig'] = (parse_mpi(stream), parse_mpi(stream))
elif p['pubkey_alg'] in RSA_ALGO_IDS: # RSA
p['sig'] = (parse_mpi(stream),)
elif p['pubkey_alg'] == DSA_ALGO_ID:
p['sig'] = (parse_mpi(stream), parse_mpi(stream))
else:
log.error('unsupported public key algo: %d', p['pubkey_alg'])
assert not stream.read()
return p
def _parse_pubkey(stream, packet_type='pubkey'):
"""See https://tools.ietf.org/html/rfc4880#section-5.5 for details."""
p = {'type': packet_type}
packet = io.BytesIO()
with stream.capture(packet):
p['version'] = stream.readfmt('B')
p['created'] = stream.readfmt('>L')
p['algo'] = stream.readfmt('B')
if p['algo'] in ECDSA_ALGO_IDS:
log.debug('parsing elliptic curve key')
# https://tools.ietf.org/html/rfc6637#section-11
oid_size = stream.readfmt('B')
oid = stream.read(oid_size)
assert oid in SUPPORTED_CURVES, util.hexlify(oid)
p['curve_oid'] = oid
mpi = parse_mpi(stream)
log.debug('mpi: %x (%d bits)', mpi, mpi.bit_length())
leftover = stream.read()
if leftover:
leftover = io.BytesIO(leftover)
# https://tools.ietf.org/html/rfc6637#section-8
# should be b'\x03\x01\x08\x07': SHA256 + AES128
size, = util.readfmt(leftover, 'B')
p['kdf'] = leftover.read(size)
p['secret'] = leftover.read()
parse_func, keygrip_func = SUPPORTED_CURVES[oid]
keygrip = keygrip_func(parse_func(mpi))
log.debug('keygrip: %s', util.hexlify(keygrip))
p['keygrip'] = keygrip
elif p['algo'] == DSA_ALGO_ID:
parse_mpis(stream, n=4) # DSA keys are not supported
elif p['algo'] == ELGAMAL_ALGO_ID:
parse_mpis(stream, n=3) # ElGamal keys are not supported
else: # assume RSA
parse_mpis(stream, n=2) # RSA keys are not supported
assert not stream.read()
# https://tools.ietf.org/html/rfc4880#section-12.2
packet_data = packet.getvalue()
data_to_hash = (b'\x99' + struct.pack('>H', len(packet_data)) +
packet_data)
p['key_id'] = hashlib.sha1(data_to_hash).digest()[-8:]
p['_to_hash'] = data_to_hash
log.debug('key ID: %s', util.hexlify(p['key_id']))
return p
_parse_subkey = functools.partial(_parse_pubkey, packet_type='subkey')
def _parse_user_id(stream, packet_type='user_id'):
"""See https://tools.ietf.org/html/rfc4880#section-5.11 for details."""
value = stream.read()
to_hash = b'\xb4' + util.prefix_len('>L', value)
return {'type': packet_type, 'value': value, '_to_hash': to_hash}
# User attribute is handled as an opaque user ID
_parse_attribute = functools.partial(_parse_user_id,
packet_type='user_attribute')
PACKET_TYPES = {
2: _parse_signature,
5: _parse_pubkey,
6: _parse_pubkey,
7: _parse_subkey,
13: _parse_user_id,
14: _parse_subkey,
17: _parse_attribute,
}
def parse_packets(stream):
"""
Support iterative parsing of available GPG packets.
See https://tools.ietf.org/html/rfc4880#section-4.2 for details.
"""
reader = util.Reader(stream)
while True:
try:
value = reader.readfmt('B')
except EOFError:
return
log.debug('prefix byte: %s', bin(value))
assert util.bit(value, 7) == 1
tag = util.low_bits(value, 6)
if util.bit(value, 6) == 0:
length_type = util.low_bits(tag, 2)
tag = tag >> 2
fmt = {0: '>B', 1: '>H', 2: '>L'}[length_type]
packet_size = reader.readfmt(fmt)
else:
first = reader.readfmt('B')
if first < 192:
packet_size = first
elif first < 224:
packet_size = ((first - 192) << 8) + reader.readfmt('B') + 192
elif first == 255:
packet_size = reader.readfmt('>L')
else:
log.error('Partial Body Lengths unsupported')
log.debug('packet length: %d', packet_size)
packet_data = reader.read(packet_size)
packet_type = PACKET_TYPES.get(tag)
p = {'type': 'unknown', 'tag': tag, 'raw': packet_data}
if packet_type is not None:
try:
p = packet_type(util.Reader(io.BytesIO(packet_data)))
p['tag'] = tag
except ValueError:
log.exception('Skipping packet: %s', util.hexlify(packet_data))
log.debug('packet "%s": %s', p['type'], p)
yield p
def digest_packets(packets, hasher):
"""Compute digest on specified packets, according to '_to_hash' field."""
data_to_hash = io.BytesIO()
for p in packets:
data_to_hash.write(p['_to_hash'])
hasher.update(data_to_hash.getvalue())
return hasher.digest()
HASH_ALGORITHMS = {
1: 'md5',
2: 'sha1',
3: 'ripemd160',
8: 'sha256',
9: 'sha384',
10: 'sha512',
11: 'sha224',
}
def load_by_keygrip(pubkey_bytes, keygrip):
"""Return public key and first user ID for specified keygrip."""
stream = io.BytesIO(pubkey_bytes)
packets = list(parse_packets(stream))
packets_per_pubkey = []
for p in packets:
if p['type'] == 'pubkey':
# Add a new packet list for each pubkey.
packets_per_pubkey.append([])
packets_per_pubkey[-1].append(p)
for packets in packets_per_pubkey:
user_ids = [p for p in packets if p['type'] == 'user_id']
for p in packets:
if p.get('keygrip') == keygrip:
return p, user_ids
raise KeyError('{} keygrip not found'.format(util.hexlify(keygrip)))
def load_signature(stream, original_data):
"""Load signature from stream, and compute GPG digest for verification."""
signature, = list(parse_packets((stream)))
hash_alg = HASH_ALGORITHMS[signature['hash_alg']]
digest = digest_packets([{'_to_hash': original_data}, signature],
hasher=hashlib.new(hash_alg))
assert signature['hash_prefix'] == digest[:2]
return signature, digest
def remove_armor(armored_data):
"""Decode armored data into its binary form."""
stream = io.BytesIO(armored_data)
lines = stream.readlines()[3:-1]
data = base64.b64decode(b''.join(lines))
payload, checksum = data[:-3], data[-3:]
assert util.crc24(payload) == checksum
return payload

103
libagent/gpg/encode.py Normal file
View File

@@ -0,0 +1,103 @@
"""Create GPG ECDSA signatures and public keys using TREZOR device."""
import io
import logging
from . import decode, keyring, protocol
from .. import util
log = logging.getLogger(__name__)
def create_primary(user_id, pubkey, signer_func, secret_bytes=b''):
"""Export new primary GPG public key, ready for "gpg2 --import"."""
pubkey_packet = protocol.packet(tag=(5 if secret_bytes else 6),
blob=(pubkey.data() + secret_bytes))
user_id_packet = protocol.packet(tag=13,
blob=user_id.encode('ascii'))
data_to_sign = (pubkey.data_to_hash() +
user_id_packet[:1] +
util.prefix_len('>L', user_id.encode('ascii')))
hashed_subpackets = [
protocol.subpacket_time(pubkey.created), # signature time
# https://tools.ietf.org/html/rfc4880#section-5.2.3.7
protocol.subpacket_byte(0x0B, 9), # preferred symmetric algo (AES-256)
# https://tools.ietf.org/html/rfc4880#section-5.2.3.4
protocol.subpacket_byte(0x1B, 1 | 2), # key flags (certify & sign)
# https://tools.ietf.org/html/rfc4880#section-5.2.3.21
protocol.subpacket_byte(0x15, 8), # preferred hash (SHA256)
# https://tools.ietf.org/html/rfc4880#section-5.2.3.8
protocol.subpacket_byte(0x16, 0), # preferred compression (none)
# https://tools.ietf.org/html/rfc4880#section-5.2.3.9
protocol.subpacket_byte(0x17, 0x80) # key server prefs (no-modify)
# https://tools.ietf.org/html/rfc4880#section-5.2.3.17
]
unhashed_subpackets = [
protocol.subpacket(16, pubkey.key_id()), # issuer key id
protocol.CUSTOM_SUBPACKET]
signature = protocol.make_signature(
signer_func=signer_func,
public_algo=pubkey.algo_id,
data_to_sign=data_to_sign,
sig_type=0x13, # user id & public key
hashed_subpackets=hashed_subpackets,
unhashed_subpackets=unhashed_subpackets)
sign_packet = protocol.packet(tag=2, blob=signature)
return pubkey_packet + user_id_packet + sign_packet
def create_subkey(primary_bytes, subkey, signer_func, secret_bytes=b''):
"""Export new subkey to GPG primary key."""
subkey_packet = protocol.packet(tag=(7 if secret_bytes else 14),
blob=(subkey.data() + secret_bytes))
packets = list(decode.parse_packets(io.BytesIO(primary_bytes)))
primary, user_id, signature = packets[:3]
data_to_sign = primary['_to_hash'] + subkey.data_to_hash()
if subkey.ecdh:
embedded_sig = None
else:
# Primary Key Binding Signature
hashed_subpackets = [
protocol.subpacket_time(subkey.created)] # signature time
unhashed_subpackets = [
protocol.subpacket(16, subkey.key_id())] # issuer key id
embedded_sig = protocol.make_signature(
signer_func=signer_func,
data_to_sign=data_to_sign,
public_algo=subkey.algo_id,
sig_type=0x19,
hashed_subpackets=hashed_subpackets,
unhashed_subpackets=unhashed_subpackets)
# Subkey Binding Signature
# Key flags: https://tools.ietf.org/html/rfc4880#section-5.2.3.21
# (certify & sign) (encrypt)
flags = (2) if (not subkey.ecdh) else (4 | 8)
hashed_subpackets = [
protocol.subpacket_time(subkey.created), # signature time
protocol.subpacket_byte(0x1B, flags)]
unhashed_subpackets = []
unhashed_subpackets.append(protocol.subpacket(16, primary['key_id']))
if embedded_sig is not None:
unhashed_subpackets.append(protocol.subpacket(32, embedded_sig))
unhashed_subpackets.append(protocol.CUSTOM_SUBPACKET)
if not decode.has_custom_subpacket(signature):
signer_func = keyring.create_agent_signer(user_id['value'])
signature = protocol.make_signature(
signer_func=signer_func,
data_to_sign=data_to_sign,
public_algo=primary['algo'],
sig_type=0x18,
hashed_subpackets=hashed_subpackets,
unhashed_subpackets=unhashed_subpackets)
sign_packet = protocol.packet(tag=2, blob=signature)
return primary_bytes + subkey_packet + sign_packet

229
libagent/gpg/keyring.py Normal file
View File

@@ -0,0 +1,229 @@
"""Tools for doing signature using gpg-agent."""
from __future__ import absolute_import, print_function, unicode_literals
import binascii
import io
import logging
import os
import re
import socket
import subprocess
from .. import util
log = logging.getLogger(__name__)
def get_agent_sock_path(sp=subprocess):
"""Parse gpgconf output to find out GPG agent UNIX socket path."""
lines = sp.check_output(['gpgconf', '--list-dirs']).strip().split(b'\n')
dirs = dict(line.split(b':', 1) for line in lines)
return dirs[b'agent-socket']
def connect_to_agent(sp=subprocess):
"""Connect to GPG agent's UNIX socket."""
sock_path = get_agent_sock_path(sp=sp)
sp.check_call(['gpg-connect-agent', '/bye']) # Make sure it's running
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
sock.connect(sock_path)
return sock
def communicate(sock, msg):
"""Send a message and receive a single line."""
sendline(sock, msg.encode('ascii'))
return recvline(sock)
def sendline(sock, msg):
"""Send a binary message, followed by EOL."""
log.debug('<- %r', msg)
sock.sendall(msg + b'\n')
def recvline(sock):
"""Receive a single line from the socket."""
reply = io.BytesIO()
while True:
c = sock.recv(1)
if not c:
return None # socket is closed
if c == b'\n':
break
reply.write(c)
result = reply.getvalue()
log.debug('-> %r', result)
return result
def iterlines(conn):
"""Iterate over input, split by lines."""
while True:
line = recvline(conn)
if line is None:
break
yield line
def unescape(s):
"""Unescape ASSUAN message (0xAB <-> '%AB')."""
s = bytearray(s)
i = 0
while i < len(s):
if s[i] == ord('%'):
hex_bytes = bytes(s[i+1:i+3])
value = int(hex_bytes.decode('ascii'), 16)
s[i:i+3] = [value]
i += 1
return bytes(s)
def parse_term(s):
"""Parse single s-expr term from bytes."""
size, s = s.split(b':', 1)
size = int(size)
return s[:size], s[size:]
def parse(s):
"""Parse full s-expr from bytes."""
if s.startswith(b'('):
s = s[1:]
name, s = parse_term(s)
values = [name]
while not s.startswith(b')'):
value, s = parse(s)
values.append(value)
return values, s[1:]
else:
return parse_term(s)
def _parse_ecdsa_sig(args):
(r, sig_r), (s, sig_s) = args
assert r == b'r'
assert s == b's'
return (util.bytes2num(sig_r),
util.bytes2num(sig_s))
# DSA and EDDSA happen to have the same structure as ECDSA signatures
_parse_dsa_sig = _parse_ecdsa_sig
_parse_eddsa_sig = _parse_ecdsa_sig
def _parse_rsa_sig(args):
(s, sig_s), = args
assert s == b's'
return (util.bytes2num(sig_s),)
def parse_sig(sig):
"""Parse signature integer values from s-expr."""
label, sig = sig
assert label == b'sig-val'
algo_name = sig[0]
parser = {b'rsa': _parse_rsa_sig,
b'ecdsa': _parse_ecdsa_sig,
b'eddsa': _parse_eddsa_sig,
b'dsa': _parse_dsa_sig}[algo_name]
return parser(args=sig[1:])
def sign_digest(sock, keygrip, digest, sp=subprocess, environ=None):
"""Sign a digest using specified key using GPG agent."""
hash_algo = 8 # SHA256
assert len(digest) == 32
assert communicate(sock, 'RESET').startswith(b'OK')
ttyname = sp.check_output(['tty']).strip()
options = ['ttyname={}'.format(ttyname)] # set TTY for passphrase entry
display = (environ or os.environ).get('DISPLAY')
if display is not None:
options.append('display={}'.format(display))
for opt in options:
assert communicate(sock, 'OPTION {}'.format(opt)) == b'OK'
assert communicate(sock, 'SIGKEY {}'.format(keygrip)) == b'OK'
hex_digest = binascii.hexlify(digest).upper().decode('ascii')
assert communicate(sock, 'SETHASH {} {}'.format(hash_algo,
hex_digest)) == b'OK'
assert communicate(sock, 'SETKEYDESC '
'Sign+a+new+TREZOR-based+subkey') == b'OK'
assert communicate(sock, 'PKSIGN') == b'OK'
while True:
line = recvline(sock).strip()
if line.startswith(b'S PROGRESS'):
continue
else:
break
line = unescape(line)
log.debug('unescaped: %r', line)
prefix, sig = line.split(b' ', 1)
if prefix != b'D':
raise ValueError(prefix)
sig, leftover = parse(sig)
assert not leftover, leftover
return parse_sig(sig)
def gpg_command(args, env=None):
"""Prepare common GPG command line arguments."""
if env is None:
env = os.environ
cmd = ['gpg2']
homedir = env.get('GNUPGHOME')
if homedir:
cmd.extend(['--homedir', homedir])
return cmd + args
def get_keygrip(user_id, sp=subprocess):
"""Get a keygrip of the primary GPG key of the specified user."""
args = gpg_command(['--list-keys', '--with-keygrip', user_id])
output = sp.check_output(args).decode('ascii')
return re.findall(r'Keygrip = (\w+)', output)[0]
def gpg_version(sp=subprocess):
"""Get a keygrip of the primary GPG key of the specified user."""
args = gpg_command(['--version'])
output = sp.check_output(args)
line = output.split(b'\n')[0] # b'gpg (GnuPG) 2.1.11'
return line.split(b' ')[-1] # b'2.1.11'
def export_public_key(user_id, sp=subprocess):
"""Export GPG public key for specified `user_id`."""
args = gpg_command(['--export', user_id])
result = sp.check_output(args=args)
if not result:
log.error('could not find public key %r in local GPG keyring', user_id)
raise KeyError(user_id)
return result
def export_public_keys(sp=subprocess):
"""Export all GPG public keys."""
args = gpg_command(['--export'])
return sp.check_output(args=args)
def create_agent_signer(user_id):
"""Sign digest with existing GPG keys using gpg-agent tool."""
sock = connect_to_agent()
keygrip = get_keygrip(user_id)
def sign(digest):
"""Sign the digest and return an ECDSA/RSA/DSA signature."""
return sign_digest(sock=sock, keygrip=keygrip, digest=digest)
return sign

271
libagent/gpg/protocol.py Normal file
View File

@@ -0,0 +1,271 @@
"""GPG protocol utilities."""
import base64
import hashlib
import logging
import struct
from .. import formats, util
log = logging.getLogger(__name__)
def packet(tag, blob):
"""Create small GPG packet."""
assert len(blob) < 2**32
if len(blob) < 2**8:
length_type = 0
elif len(blob) < 2**16:
length_type = 1
else:
length_type = 2
fmt = ['>B', '>H', '>L'][length_type]
leading_byte = 0x80 | (tag << 2) | (length_type)
return struct.pack('>B', leading_byte) + util.prefix_len(fmt, blob)
def subpacket(subpacket_type, fmt, *values):
"""Create GPG subpacket."""
blob = struct.pack(fmt, *values) if values else fmt
return struct.pack('>B', subpacket_type) + blob
def subpacket_long(subpacket_type, value):
"""Create GPG subpacket with 32-bit unsigned integer."""
return subpacket(subpacket_type, '>L', value)
def subpacket_time(value):
"""Create GPG subpacket with time in seconds (since Epoch)."""
return subpacket_long(2, value)
def subpacket_byte(subpacket_type, value):
"""Create GPG subpacket with 8-bit unsigned integer."""
return subpacket(subpacket_type, '>B', value)
def subpacket_prefix_len(item):
"""Prefix subpacket length according to RFC 4880 section-5.2.3.1."""
n = len(item)
if n >= 8384:
prefix = b'\xFF' + struct.pack('>L', n)
elif n >= 192:
n = n - 192
prefix = struct.pack('BB', (n // 256) + 192, n % 256)
else:
prefix = struct.pack('B', n)
return prefix + item
def subpackets(*items):
"""Serialize several GPG subpackets."""
prefixed = [subpacket_prefix_len(item) for item in items]
return util.prefix_len('>H', b''.join(prefixed))
def mpi(value):
"""Serialize multipresicion integer using GPG format."""
bits = value.bit_length()
data_size = (bits + 7) // 8
data_bytes = bytearray(data_size)
for i in range(data_size):
data_bytes[i] = value & 0xFF
value = value >> 8
data_bytes.reverse()
return struct.pack('>H', bits) + bytes(data_bytes)
def _serialize_nist256(vk):
return mpi((4 << 512) |
(vk.pubkey.point.x() << 256) |
(vk.pubkey.point.y()))
def _serialize_ed25519(vk):
return mpi((0x40 << 256) |
util.bytes2num(vk.to_bytes()))
def _compute_keygrip(params):
parts = []
for name, value in params:
exp = '{}:{}{}:'.format(len(name), name, len(value))
parts.append(b'(' + exp.encode('ascii') + value + b')')
return hashlib.sha1(b''.join(parts)).digest()
def keygrip_nist256(vk):
"""Compute keygrip for NIST256 curve public keys."""
curve = vk.curve.curve
gen = vk.curve.generator
g = (4 << 512) | (gen.x() << 256) | gen.y()
point = vk.pubkey.point
q = (4 << 512) | (point.x() << 256) | point.y()
return _compute_keygrip([
['p', util.num2bytes(curve.p(), size=32)],
['a', util.num2bytes(curve.a() % curve.p(), size=32)],
['b', util.num2bytes(curve.b() % curve.p(), size=32)],
['g', util.num2bytes(g, size=65)],
['n', util.num2bytes(vk.curve.order, size=32)],
['q', util.num2bytes(q, size=65)],
])
def keygrip_ed25519(vk):
"""Compute keygrip for Ed25519 public keys."""
# pylint: disable=line-too-long
return _compute_keygrip([
['p', util.num2bytes(0x7FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFED, size=32)], # nopep8
['a', b'\x01'],
['b', util.num2bytes(0x2DFC9311D490018C7338BF8688861767FF8FF5B2BEBE27548A14B235ECA6874A, size=32)], # nopep8
['g', util.num2bytes(0x04216936D3CD6E53FEC0A4E231FDD6DC5C692CC7609525A7B2C9562D608F25D51A6666666666666666666666666666666666666666666666666666666666666658, size=65)], # nopep8
['n', util.num2bytes(0x1000000000000000000000000000000014DEF9DEA2F79CD65812631A5CF5D3ED, size=32)], # nopep8
['q', vk.to_bytes()],
])
def keygrip_curve25519(vk):
"""Compute keygrip for Curve25519 public keys."""
# pylint: disable=line-too-long
return _compute_keygrip([
['p', util.num2bytes(0x7FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFED, size=32)], # nopep8
['a', b'\x01\xDB\x41'],
['b', b'\x01'],
['g', util.num2bytes(0x04000000000000000000000000000000000000000000000000000000000000000920ae19a1b8a086b4e01edd2c7748d14c923d4d7e6d7c61b229e9c5a27eced3d9, size=65)], # nopep8
['n', util.num2bytes(0x1000000000000000000000000000000014DEF9DEA2F79CD65812631A5CF5D3ED, size=32)], # nopep8
['q', vk.to_bytes()],
])
SUPPORTED_CURVES = {
formats.CURVE_NIST256: {
# https://tools.ietf.org/html/rfc6637#section-11
'oid': b'\x2A\x86\x48\xCE\x3D\x03\x01\x07',
'algo_id': 19,
'serialize': _serialize_nist256,
'keygrip': keygrip_nist256,
},
formats.CURVE_ED25519: {
'oid': b'\x2B\x06\x01\x04\x01\xDA\x47\x0F\x01',
'algo_id': 22,
'serialize': _serialize_ed25519,
'keygrip': keygrip_ed25519,
},
formats.ECDH_CURVE25519: {
'oid': b'\x2B\x06\x01\x04\x01\x97\x55\x01\x05\x01',
'algo_id': 18,
'serialize': _serialize_ed25519,
'keygrip': keygrip_curve25519,
},
}
ECDH_ALGO_ID = 18
CUSTOM_KEY_LABEL = b'TREZOR-GPG' # marks "our" pubkey
CUSTOM_SUBPACKET_ID = 26 # use "policy URL" subpacket
CUSTOM_SUBPACKET = subpacket(CUSTOM_SUBPACKET_ID, CUSTOM_KEY_LABEL)
def get_curve_name_by_oid(oid):
"""Return curve name matching specified OID, or raise KeyError."""
for curve_name, info in SUPPORTED_CURVES.items():
if info['oid'] == oid:
return curve_name
raise KeyError('Unknown OID: {!r}'.format(oid))
class PublicKey(object):
"""GPG representation for public key packets."""
def __init__(self, curve_name, created, verifying_key, ecdh=False):
"""Contruct using a ECDSA VerifyingKey object."""
self.curve_name = curve_name
self.curve_info = SUPPORTED_CURVES[curve_name]
self.created = int(created) # time since Epoch
self.verifying_key = verifying_key
self.ecdh = bool(ecdh)
if ecdh:
self.algo_id = ECDH_ALGO_ID
self.ecdh_packet = b'\x03\x01\x08\x07'
else:
self.algo_id = self.curve_info['algo_id']
self.ecdh_packet = b''
def keygrip(self):
"""Compute GPG keygrip of the verifying key."""
return self.curve_info['keygrip'](self.verifying_key)
def data(self):
"""Data for packet creation."""
header = struct.pack('>BLB',
4, # version
self.created, # creation
self.algo_id) # public key algorithm ID
oid = util.prefix_len('>B', self.curve_info['oid'])
blob = self.curve_info['serialize'](self.verifying_key)
return header + oid + blob + self.ecdh_packet
def data_to_hash(self):
"""Data for digest computation."""
return b'\x99' + util.prefix_len('>H', self.data())
def _fingerprint(self):
return hashlib.sha1(self.data_to_hash()).digest()
def key_id(self):
"""Short (8 byte) GPG key ID."""
return self._fingerprint()[-8:]
def __repr__(self):
"""Short (8 hexadecimal digits) GPG key ID."""
hex_key_id = util.hexlify(self.key_id())[-8:]
return 'GPG public key {}/{}'.format(self.curve_name, hex_key_id)
__str__ = __repr__
def _split_lines(body, size):
lines = []
for i in range(0, len(body), size):
lines.append(body[i:i+size] + '\n')
return ''.join(lines)
def armor(blob, type_str):
"""See https://tools.ietf.org/html/rfc4880#section-6 for details."""
head = '-----BEGIN PGP {}-----\nVersion: GnuPG v2\n\n'.format(type_str)
body = base64.b64encode(blob).decode('ascii')
checksum = base64.b64encode(util.crc24(blob)).decode('ascii')
tail = '-----END PGP {}-----\n'.format(type_str)
return head + _split_lines(body, 64) + '=' + checksum + '\n' + tail
def make_signature(signer_func, data_to_sign, public_algo,
hashed_subpackets, unhashed_subpackets, sig_type=0):
"""Create new GPG signature."""
# pylint: disable=too-many-arguments
header = struct.pack('>BBBB',
4, # version
sig_type, # rfc4880 (section-5.2.1)
public_algo,
8) # hash_alg (SHA256)
hashed = subpackets(*hashed_subpackets)
unhashed = subpackets(*unhashed_subpackets)
tail = b'\x04\xff' + struct.pack('>L', len(header) + len(hashed))
data_to_hash = data_to_sign + header + hashed + tail
log.debug('hashing %d bytes', len(data_to_hash))
digest = hashlib.sha256(data_to_hash).digest()
log.debug('signing digest: %s', util.hexlify(digest))
params = signer_func(digest=digest)
sig = b''.join(mpi(p) for p in params)
return bytes(header + hashed + unhashed +
digest[:2] + # used for decoder's sanity check
sig) # actual ECDSA signature

View File

@@ -0,0 +1 @@
"""Tests for GPG module."""

View File

@@ -0,0 +1,62 @@
import glob
import io
import os
import pytest
from .. import decode, protocol
from ... import util
def test_subpackets():
s = io.BytesIO(b'\x00\x05\x02\xAB\xCD\x01\xEF')
assert decode.parse_subpackets(util.Reader(s)) == [b'\xAB\xCD', b'\xEF']
def test_subpackets_prefix():
for n in [0, 1, 2, 4, 5, 10, 191, 192, 193,
255, 256, 257, 8383, 8384, 65530]:
item = b'?' * n # create dummy subpacket
prefixed = protocol.subpackets(item)
result = decode.parse_subpackets(util.Reader(io.BytesIO(prefixed)))
assert [item] == result
def test_mpi():
s = io.BytesIO(b'\x00\x09\x01\x23')
assert decode.parse_mpi(util.Reader(s)) == 0x123
s = io.BytesIO(b'\x00\x09\x01\x23\x00\x03\x05')
assert decode.parse_mpis(util.Reader(s), n=2) == [0x123, 5]
cwd = os.path.join(os.path.dirname(__file__))
input_files = glob.glob(os.path.join(cwd, '*.gpg'))
@pytest.fixture(params=input_files)
def public_key_path(request):
return request.param
def test_gpg_files(public_key_path): # pylint: disable=redefined-outer-name
with open(public_key_path, 'rb') as f:
assert list(decode.parse_packets(f))
def test_has_custom_subpacket():
sig = {'unhashed_subpackets': []}
assert not decode.has_custom_subpacket(sig)
custom_markers = [
protocol.CUSTOM_SUBPACKET,
protocol.subpacket(10, protocol.CUSTOM_KEY_LABEL),
]
for marker in custom_markers:
sig = {'unhashed_subpackets': [marker]}
assert decode.has_custom_subpacket(sig)
def test_load_by_keygrip_missing():
with pytest.raises(KeyError):
decode.load_by_keygrip(pubkey_bytes=b'', keygrip=b'')

View File

@@ -0,0 +1,101 @@
import io
import mock
from .. import keyring
def test_unescape_short():
assert keyring.unescape(b'abc%0AX%0D %25;.-+()') == b'abc\nX\r %;.-+()'
def test_unescape_long():
escaped = (b'D (7:sig-val(3:dsa(1:r32:\x1d\x15.\x12\xe8h\x19\xd9O\xeb\x06'
b'yD?a:/\xae\xdb\xac\x93\xa6\x86\xcbs\xb8\x03\xf1\xcb\x89\xc7'
b'\x1f)(1:s32:%25\xb5\x04\x94\xc7\xc4X\xc7\xe0%0D\x08\xbb%0DuN'
b'\x9c6}[\xc2=t\x8c\xfdD\x81\xe8\xdd\x86=\xe2\xa9)))')
unescaped = (b'D (7:sig-val(3:dsa(1:r32:\x1d\x15.\x12\xe8h\x19\xd9O\xeb'
b'\x06yD?a:/\xae\xdb\xac\x93\xa6\x86\xcbs\xb8\x03\xf1\xcb\x89'
b'\xc7\x1f)(1:s32:%\xb5\x04\x94\xc7\xc4X\xc7\xe0\r\x08\xbb\ru'
b'N\x9c6}[\xc2=t\x8c\xfdD\x81\xe8\xdd\x86=\xe2\xa9)))')
assert keyring.unescape(escaped) == unescaped
def test_parse_term():
assert keyring.parse(b'4:abcdXXX') == (b'abcd', b'XXX')
def test_parse_ecdsa():
sig, rest = keyring.parse(b'(7:sig-val(5:ecdsa'
b'(1:r2:\x01\x02)(1:s2:\x03\x04)))')
values = [[b'r', b'\x01\x02'], [b's', b'\x03\x04']]
assert sig == [b'sig-val', [b'ecdsa'] + values]
assert rest == b''
assert keyring.parse_sig(sig) == (0x102, 0x304)
def test_parse_rsa():
sig, rest = keyring.parse(b'(7:sig-val(3:rsa(1:s4:\x01\x02\x03\x04)))')
assert sig == [b'sig-val', [b'rsa', [b's', b'\x01\x02\x03\x04']]]
assert rest == b''
assert keyring.parse_sig(sig) == (0x1020304,)
class FakeSocket(object):
def __init__(self):
self.rx = io.BytesIO()
self.tx = io.BytesIO()
def recv(self, n):
return self.rx.read(n)
def sendall(self, data):
self.tx.write(data)
def test_sign_digest():
sock = FakeSocket()
sock.rx.write(b'OK Pleased to meet you, process XYZ\n')
sock.rx.write(b'OK\n' * 6)
sock.rx.write(b'D (7:sig-val(3:rsa(1:s16:0123456789ABCDEF)))\n')
sock.rx.seek(0)
keygrip = '1234'
digest = b'A' * 32
sp = mock.Mock(spec=['check_output'])
sp.check_output.return_value = '/dev/pts/0'
sig = keyring.sign_digest(sock=sock, keygrip=keygrip,
digest=digest, sp=sp,
environ={'DISPLAY': ':0'})
assert sig == (0x30313233343536373839414243444546,)
assert sock.tx.getvalue() == b'''RESET
OPTION ttyname=/dev/pts/0
OPTION display=:0
SIGKEY 1234
SETHASH 8 4141414141414141414141414141414141414141414141414141414141414141
SETKEYDESC Sign+a+new+TREZOR-based+subkey
PKSIGN
'''
def test_iterlines():
sock = FakeSocket()
sock.rx.write(b'foo\nbar\nxyz')
sock.rx.seek(0)
assert list(keyring.iterlines(sock)) == [b'foo', b'bar']
def test_get_agent_sock_path():
sp = mock.Mock(spec=['check_output'])
sp.check_output.return_value = b'''sysconfdir:/usr/local/etc/gnupg
bindir:/usr/local/bin
libexecdir:/usr/local/libexec
libdir:/usr/local/lib/gnupg
datadir:/usr/local/share/gnupg
localedir:/usr/local/share/locale
dirmngr-socket:/run/user/1000/gnupg/S.dirmngr
agent-ssh-socket:/run/user/1000/gnupg/S.gpg-agent.ssh
agent-socket:/run/user/1000/gnupg/S.gpg-agent
homedir:/home/roman/.gnupg
'''
expected = b'/run/user/1000/gnupg/S.gpg-agent'
assert keyring.get_agent_sock_path(sp=sp) == expected

View File

@@ -0,0 +1,107 @@
import ecdsa
import ed25519
import pytest
from .. import protocol
from ... import formats
def test_packet():
assert protocol.packet(1, b'') == b'\x84\x00'
assert protocol.packet(2, b'A') == b'\x88\x01A'
blob = b'B' * 0xAB
assert protocol.packet(3, blob) == b'\x8c\xAB' + blob
blob = b'C' * 0x1234
assert protocol.packet(3, blob) == b'\x8d\x12\x34' + blob
blob = b'D' * 0x12345678
assert protocol.packet(4, blob) == b'\x92\x12\x34\x56\x78' + blob
def test_subpackets():
assert protocol.subpacket(1, b'') == b'\x01'
assert protocol.subpacket(2, '>H', 0x0304) == b'\x02\x03\x04'
assert protocol.subpacket_long(9, 0x12345678) == b'\x09\x12\x34\x56\x78'
assert protocol.subpacket_time(0x12345678) == b'\x02\x12\x34\x56\x78'
assert protocol.subpacket_byte(0xAB, 0xCD) == b'\xAB\xCD'
assert protocol.subpackets() == b'\x00\x00'
assert protocol.subpackets(b'ABC', b'12345') == b'\x00\x0A\x03ABC\x0512345'
def test_mpi():
assert protocol.mpi(0x123) == b'\x00\x09\x01\x23'
def test_armor():
data = bytearray(range(256))
assert protocol.armor(data, 'TEST') == '''-----BEGIN PGP TEST-----
Version: GnuPG v2
AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8gISIjJCUmJygpKissLS4v
MDEyMzQ1Njc4OTo7PD0+P0BBQkNERUZHSElKS0xNTk9QUVJTVFVWV1hZWltcXV5f
YGFiY2RlZmdoaWprbG1ub3BxcnN0dXZ3eHl6e3x9fn+AgYKDhIWGh4iJiouMjY6P
kJGSk5SVlpeYmZqbnJ2en6ChoqOkpaanqKmqq6ytrq+wsbKztLW2t7i5uru8vb6/
wMHCw8TFxsfIycrLzM3Oz9DR0tPU1dbX2Nna29zd3t/g4eLj5OXm5+jp6uvs7e7v
8PHy8/T19vf4+fr7/P3+/w==
=W700
-----END PGP TEST-----
'''
def test_make_signature():
def signer_func(digest):
assert digest == (b'\xd0\xe5]|\x8bP\xe6\x91\xb3\xe8+\xf4A\xf0`(\xb1'
b'\xc7\xf4;\x86\x97s\xdb\x9a\xda\xee< \xcb\x9e\x00')
return (7, 8)
sig = protocol.make_signature(
signer_func=signer_func,
data_to_sign=b'Hello World!',
public_algo=22,
hashed_subpackets=[protocol.subpacket_time(1)],
unhashed_subpackets=[],
sig_type=25)
assert sig == (b'\x04\x19\x16\x08\x00\x06\x05\x02'
b'\x00\x00\x00\x01\x00\x00\xd0\xe5\x00\x03\x07\x00\x04\x08')
def test_nist256p1():
sk = ecdsa.SigningKey.from_secret_exponent(secexp=1, curve=ecdsa.NIST256p)
vk = sk.get_verifying_key()
pk = protocol.PublicKey(curve_name=formats.CURVE_NIST256,
created=42, verifying_key=vk)
assert repr(pk) == 'GPG public key nist256p1/F82361D9'
assert pk.keygrip() == b'\x95\x85.\x91\x7f\xe2\xc3\x91R\xba\x99\x81\x92\xb5y\x1d\xb1\\\xdc\xf0'
def test_nist256p1_ecdh():
sk = ecdsa.SigningKey.from_secret_exponent(secexp=1, curve=ecdsa.NIST256p)
vk = sk.get_verifying_key()
pk = protocol.PublicKey(curve_name=formats.CURVE_NIST256,
created=42, verifying_key=vk, ecdh=True)
assert repr(pk) == 'GPG public key nist256p1/5811DF46'
assert pk.keygrip() == b'\x95\x85.\x91\x7f\xe2\xc3\x91R\xba\x99\x81\x92\xb5y\x1d\xb1\\\xdc\xf0'
def test_ed25519():
sk = ed25519.SigningKey(b'\x00' * 32)
vk = sk.get_verifying_key()
pk = protocol.PublicKey(curve_name=formats.CURVE_ED25519,
created=42, verifying_key=vk)
assert repr(pk) == 'GPG public key ed25519/36B40FE6'
assert pk.keygrip() == b'\xbf\x01\x90l\x17\xb64\xa3-\xf4\xc0gr\x99\x18<\xddBQ?'
def test_curve25519():
sk = ed25519.SigningKey(b'\x00' * 32)
vk = sk.get_verifying_key()
pk = protocol.PublicKey(curve_name=formats.ECDH_CURVE25519,
created=42, verifying_key=vk)
assert repr(pk) == 'GPG public key curve25519/69460384'
assert pk.keygrip() == b'x\xd6\x86\xe4\xa6\xfc;\x0fY\xe1}Lw\xc4\x9ed\xf1Q\x8a\x00'
def test_get_curve_name_by_oid():
for name, info in protocol.SUPPORTED_CURVES.items():
assert protocol.get_curve_name_by_oid(info['oid']) == name
with pytest.raises(KeyError):
protocol.get_curve_name_by_oid('BAD_OID')

160
libagent/protocol.py Normal file
View File

@@ -0,0 +1,160 @@
"""
SSH-agent protocol implementation library.
See https://github.com/openssh/openssh-portable/blob/master/PROTOCOL.agent and
http://ptspts.blogspot.co.il/2010/06/how-to-use-ssh-agent-programmatically.html
for more details.
The server's source code can be found here:
https://github.com/openssh/openssh-portable/blob/master/authfd.c
"""
import io
import logging
from . import formats, util
log = logging.getLogger(__name__)
# Taken from https://github.com/openssh/openssh-portable/blob/master/authfd.h
COMMANDS = dict(
SSH_AGENTC_REQUEST_RSA_IDENTITIES=1,
SSH_AGENT_RSA_IDENTITIES_ANSWER=2,
SSH_AGENTC_RSA_CHALLENGE=3,
SSH_AGENT_RSA_RESPONSE=4,
SSH_AGENT_FAILURE=5,
SSH_AGENT_SUCCESS=6,
SSH_AGENTC_ADD_RSA_IDENTITY=7,
SSH_AGENTC_REMOVE_RSA_IDENTITY=8,
SSH_AGENTC_REMOVE_ALL_RSA_IDENTITIES=9,
SSH2_AGENTC_REQUEST_IDENTITIES=11,
SSH2_AGENT_IDENTITIES_ANSWER=12,
SSH2_AGENTC_SIGN_REQUEST=13,
SSH2_AGENT_SIGN_RESPONSE=14,
SSH2_AGENTC_ADD_IDENTITY=17,
SSH2_AGENTC_REMOVE_IDENTITY=18,
SSH2_AGENTC_REMOVE_ALL_IDENTITIES=19,
SSH_AGENTC_ADD_SMARTCARD_KEY=20,
SSH_AGENTC_REMOVE_SMARTCARD_KEY=21,
SSH_AGENTC_LOCK=22,
SSH_AGENTC_UNLOCK=23,
SSH_AGENTC_ADD_RSA_ID_CONSTRAINED=24,
SSH2_AGENTC_ADD_ID_CONSTRAINED=25,
SSH_AGENTC_ADD_SMARTCARD_KEY_CONSTRAINED=26,
)
def msg_code(name):
"""Convert string name into a integer message code."""
return COMMANDS[name]
def msg_name(code):
"""Convert integer message code into a string name."""
ids = {v: k for k, v in COMMANDS.items()}
return ids[code]
def failure():
"""Return error code to SSH binary."""
error_msg = util.pack('B', msg_code('SSH_AGENT_FAILURE'))
return util.frame(error_msg)
def _legacy_pubs(buf):
"""SSH v1 public keys are not supported."""
leftover = buf.read()
if leftover:
log.warning('skipping leftover: %r', leftover)
code = util.pack('B', msg_code('SSH_AGENT_RSA_IDENTITIES_ANSWER'))
num = util.pack('L', 0) # no SSH v1 keys
return util.frame(code, num)
class Handler(object):
"""ssh-agent protocol handler."""
def __init__(self, conn, debug=False):
"""
Create a protocol handler with specified public keys.
Use specified signer function to sign SSH authentication requests.
"""
self.conn = conn
self.debug = debug
self.methods = {
msg_code('SSH_AGENTC_REQUEST_RSA_IDENTITIES'): _legacy_pubs,
msg_code('SSH2_AGENTC_REQUEST_IDENTITIES'): self.list_pubs,
msg_code('SSH2_AGENTC_SIGN_REQUEST'): self.sign_message,
}
def handle(self, msg):
"""Handle SSH message from the SSH client and return the response."""
debug_msg = ': {!r}'.format(msg) if self.debug else ''
log.debug('request: %d bytes%s', len(msg), debug_msg)
buf = io.BytesIO(msg)
code, = util.recv(buf, '>B')
if code not in self.methods:
log.warning('Unsupported command: %s (%d)', msg_name(code), code)
return failure()
method = self.methods[code]
log.debug('calling %s()', method.__name__)
reply = method(buf=buf)
debug_reply = ': {!r}'.format(reply) if self.debug else ''
log.debug('reply: %d bytes%s', len(reply), debug_reply)
return reply
def list_pubs(self, buf):
"""SSH v2 public keys are serialized and returned."""
assert not buf.read()
keys = self.conn.parse_public_keys()
code = util.pack('B', msg_code('SSH2_AGENT_IDENTITIES_ANSWER'))
num = util.pack('L', len(keys))
log.debug('available keys: %s', [k['name'] for k in keys])
for i, k in enumerate(keys):
log.debug('%2d) %s', i+1, k['fingerprint'])
pubs = [util.frame(k['blob']) + util.frame(k['name']) for k in keys]
return util.frame(code, num, *pubs)
def sign_message(self, buf):
"""
SSH v2 public key authentication is performed.
If the required key is not supported, raise KeyError
If the signature is invalid, raise ValueError
"""
key = formats.parse_pubkey(util.read_frame(buf))
log.debug('looking for %s', key['fingerprint'])
blob = util.read_frame(buf)
assert util.read_frame(buf) == b''
assert not buf.read()
for k in self.conn.parse_public_keys():
if (k['fingerprint']) == (key['fingerprint']):
log.debug('using key %r (%s)', k['name'], k['fingerprint'])
key = k
break
else:
raise KeyError('key not found')
label = key['name'].decode('ascii') # label should be a string
log.debug('signing %d-byte blob with "%s" key', len(blob), label)
try:
signature = self.conn.sign(blob=blob, identity=key['identity'])
except IOError:
return failure()
log.debug('signature: %r', signature)
try:
sig_bytes = key['verifier'](sig=signature, msg=blob)
log.info('signature status: OK')
except formats.ecdsa.BadSignatureError:
log.exception('signature status: ERROR')
raise ValueError('invalid ECDSA signature')
log.debug('signature size: %d bytes', len(sig_bytes))
data = util.frame(util.frame(key['type']), util.frame(sig_bytes))
code = util.pack('B', msg_code('SSH2_AGENT_SIGN_RESPONSE'))
return util.frame(code, data)

166
libagent/server.py Normal file
View File

@@ -0,0 +1,166 @@
"""UNIX-domain socket server for ssh-agent implementation."""
import contextlib
import functools
import logging
import os
import socket
import subprocess
import tempfile
import threading
from . import util
log = logging.getLogger(__name__)
UNIX_SOCKET_TIMEOUT = 0.1
def remove_file(path, remove=os.remove, exists=os.path.exists):
"""Remove file, and raise OSError if still exists."""
try:
remove(path)
except OSError:
if exists(path):
raise
@contextlib.contextmanager
def unix_domain_socket_server(sock_path):
"""
Create UNIX-domain socket on specified path.
Listen on it, and delete it after the generated context is over.
"""
log.debug('serving on %s', sock_path)
remove_file(sock_path)
server = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
server.bind(sock_path)
server.listen(1)
try:
yield server
finally:
remove_file(sock_path)
def handle_connection(conn, handler, mutex):
"""
Handle a single connection using the specified protocol handler in a loop.
Since this function may be called concurrently from server_thread,
the specified mutex is used to synchronize the device handling.
Exit when EOFError is raised.
All other exceptions are logged as warnings.
"""
try:
log.debug('welcome agent')
with contextlib.closing(conn):
while True:
msg = util.read_frame(conn)
with mutex:
reply = handler.handle(msg=msg)
util.send(conn, reply)
except EOFError:
log.debug('goodbye agent')
except Exception as e: # pylint: disable=broad-except
log.warning('error: %s', e, exc_info=True)
def retry(func, exception_type, quit_event):
"""
Run the function, retrying when the specified exception_type occurs.
Poll quit_event on each iteration, to be responsive to an external
exit request.
"""
while True:
if quit_event.is_set():
raise StopIteration
try:
return func()
except exception_type:
pass
def server_thread(sock, handle_conn, quit_event):
"""Run a server on the specified socket."""
log.debug('server thread started')
def accept_connection():
conn, _ = sock.accept()
conn.settimeout(None)
return conn
while True:
log.debug('waiting for connection on %s', sock.getsockname())
try:
conn = retry(accept_connection, socket.timeout, quit_event)
except StopIteration:
log.debug('server stopped')
break
# Handle connections from SSH concurrently.
threading.Thread(target=handle_conn,
kwargs=dict(conn=conn)).start()
log.debug('server thread stopped')
@contextlib.contextmanager
def spawn(func, kwargs):
"""Spawn a thread, and join it after the context is over."""
t = threading.Thread(target=func, kwargs=kwargs)
t.start()
yield
t.join()
@contextlib.contextmanager
def serve(handler, sock_path=None, timeout=UNIX_SOCKET_TIMEOUT):
"""
Start the ssh-agent server on a UNIX-domain socket.
If no connection is made during the specified timeout,
retry until the context is over.
"""
ssh_version = subprocess.check_output(['ssh', '-V'],
stderr=subprocess.STDOUT)
log.debug('local SSH version: %r', ssh_version)
if sock_path is None:
sock_path = tempfile.mktemp(prefix='trezor-ssh-agent-')
environ = {'SSH_AUTH_SOCK': sock_path, 'SSH_AGENT_PID': str(os.getpid())}
device_mutex = threading.Lock()
with unix_domain_socket_server(sock_path) as sock:
sock.settimeout(timeout)
quit_event = threading.Event()
handle_conn = functools.partial(handle_connection,
handler=handler,
mutex=device_mutex)
kwargs = dict(sock=sock,
handle_conn=handle_conn,
quit_event=quit_event)
with spawn(server_thread, kwargs):
try:
yield environ
finally:
log.debug('closing server')
quit_event.set()
def run_process(command, environ):
"""
Run the specified process and wait until it finishes.
Use environ dict for environment variables.
"""
log.info('running %r with %r', command, environ)
env = dict(os.environ)
env.update(environ)
try:
p = subprocess.Popen(args=command, env=env)
except OSError as e:
raise OSError('cannot run %r: %s' % (command, e))
log.debug('subprocess %d is running', p.pid)
ret = p.wait()
log.debug('subprocess %d exited: %d', p.pid, ret)
return ret

205
libagent/ssh/__init__.py Normal file
View File

@@ -0,0 +1,205 @@
"""SSH-agent implementation using hardware authentication devices."""
import argparse
import functools
import logging
import os
import re
import subprocess
import sys
from .. import client, device, formats, protocol, server, util
log = logging.getLogger(__name__)
def ssh_args(label):
"""Create SSH command for connecting specified server."""
identity = device.interface.string_to_identity(label)
args = []
if 'port' in identity:
args += ['-p', identity['port']]
if 'user' in identity:
args += ['-l', identity['user']]
return args + [identity['host']]
def mosh_args(label):
"""Create SSH command for connecting specified server."""
identity = device.interface.string_to_identity(label)
args = []
if 'port' in identity:
args += ['-p', identity['port']]
if 'user' in identity:
args += [identity['user']+'@'+identity['host']]
else:
args += [identity['host']]
return args
def create_parser():
"""Create argparse.ArgumentParser for this tool."""
p = argparse.ArgumentParser()
p.add_argument('-v', '--verbose', default=0, action='count')
curve_names = [name for name in formats.SUPPORTED_CURVES]
curve_names = ', '.join(sorted(curve_names))
p.add_argument('-e', '--ecdsa-curve-name', metavar='CURVE',
default=formats.CURVE_NIST256,
help='specify ECDSA curve name: ' + curve_names)
p.add_argument('--timeout',
default=server.UNIX_SOCKET_TIMEOUT, type=float,
help='Timeout for accepting SSH client connections')
p.add_argument('--debug', default=False, action='store_true',
help='Log SSH protocol messages for debugging.')
return p
def create_agent_parser():
"""Specific parser for SSH connection."""
p = create_parser()
g = p.add_mutually_exclusive_group()
g.add_argument('-s', '--shell', default=False, action='store_true',
help='run ${SHELL} as subprocess under SSH agent')
g.add_argument('-c', '--connect', default=False, action='store_true',
help='connect to specified host via SSH')
g.add_argument('--mosh', default=False, action='store_true',
help='connect to specified host via using Mosh')
p.add_argument('identity', type=str, default=None,
help='proto://[user@]host[:port][/path]')
p.add_argument('command', type=str, nargs='*', metavar='ARGUMENT',
help='command to run under the SSH agent')
return p
def create_git_parser():
"""Specific parser for git commands."""
p = create_parser()
p.add_argument('-r', '--remote', default='origin',
help='use this git remote URL to generate SSH identity')
p.add_argument('-t', '--test', action='store_true',
help='test connection using `ssh -T user@host` command')
p.add_argument('command', type=str, nargs='*', metavar='ARGUMENT',
help='Git command to run under the SSH agent')
return p
def git_host(remote_name, attributes):
"""Extract git SSH host for specified remote name."""
try:
output = subprocess.check_output('git config --local --list'.split())
except subprocess.CalledProcessError:
return
for attribute in attributes:
name = r'remote.{0}.{1}'.format(remote_name, attribute)
matches = re.findall(re.escape(name) + '=(.*)', output)
log.debug('%r: %r', name, matches)
if not matches:
continue
url = matches[0].strip()
match = re.match('(?P<user>.*?)@(?P<host>.*?):(?P<path>.*)', url)
if match:
return '{user}@{host}'.format(**match.groupdict())
def run_server(conn, command, debug, timeout):
"""Common code for run_agent and run_git below."""
try:
handler = protocol.Handler(conn=conn, debug=debug)
with server.serve(handler=handler, timeout=timeout) as env:
return server.run_process(command=command, environ=env)
except KeyboardInterrupt:
log.info('server stopped')
def handle_connection_error(func):
"""Fail with non-zero exit code."""
@functools.wraps(func)
def wrapper(*args, **kwargs):
try:
return func(*args, **kwargs)
except IOError as e:
log.error('Connection error (try unplugging and replugging your device): %s', e)
return 1
return wrapper
def parse_config(fname):
"""Parse config file into a list of Identity objects."""
contents = open(fname).read()
for identity_str, curve_name in re.findall(r'\<(.*?)\|(.*?)\>', contents):
yield device.interface.Identity(identity_str=identity_str,
curve_name=curve_name)
class JustInTimeConnection(object):
"""Connect to the device just before the needed operation."""
def __init__(self, conn_factory, identities):
"""Create a JIT connection object."""
self.conn_factory = conn_factory
self.identities = identities
self.public_keys = util.memoize(self._public_keys) # a simple cache
def _public_keys(self):
"""Return a list of SSH public keys (in textual format)."""
conn = self.conn_factory()
return conn.export_public_keys(self.identities)
def parse_public_keys(self):
"""Parse SSH public keys into dictionaries."""
public_keys = [formats.import_public_key(pk)
for pk in self.public_keys()]
for pk, identity in zip(public_keys, self.identities):
pk['identity'] = identity
return public_keys
def sign(self, blob, identity):
"""Sign a given blob using the specified identity on the device."""
conn = self.conn_factory()
return conn.sign_ssh_challenge(blob=blob, identity=identity)
@handle_connection_error
def main(device_type):
"""Run ssh-agent using given hardware client factory."""
args = create_agent_parser().parse_args()
util.setup_logging(verbosity=args.verbose)
if args.identity.startswith('/'):
identities = list(parse_config(fname=args.identity))
else:
identities = [device.interface.Identity(
identity_str=args.identity, curve_name=args.ecdsa_curve_name)]
for index, identity in enumerate(identities):
identity.identity_dict['proto'] = 'ssh'
log.info('identity #%d: %s', index, identity)
if args.connect:
command = ['ssh'] + ssh_args(args.identity) + args.command
elif args.mosh:
command = ['mosh'] + mosh_args(args.identity) + args.command
else:
command = args.command
use_shell = bool(args.shell)
if use_shell:
command = os.environ['SHELL']
conn = JustInTimeConnection(
conn_factory=lambda: client.Client(device_type()),
identities=identities)
if command:
return run_server(conn=conn, command=command, debug=args.debug,
timeout=args.timeout)
else:
for pk in conn.public_keys():
sys.stdout.write(pk)

View File

@@ -0,0 +1 @@
"""Unit-tests for this package."""

View File

@@ -0,0 +1,72 @@
import io
import mock
import pytest
from .. import client, device, formats, util
ADDR = [2147483661, 2810943954, 3938368396, 3454558782, 3848009040]
CURVE = 'nist256p1'
PUBKEY = (b'\x03\xd8(\xb5\xa6`\xbet0\x95\xac:[;]\xdc,\xbd\xdc?\xd7\xc0\xec'
b'\xdd\xbc+\xfar~\x9dAis')
PUBKEY_TEXT = ('ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzd'
'HAyNTYAAABBBNgotaZgvnQwlaw6Wztd3Cy93D/XwOzdvCv6cn6dQWlzNMEQeW'
'VUfhvrGljR2Z/CMRONY6ejB+9PnpUOPuzYqi8= <localhost:22|nist256p1>\n')
class MockDevice(device.interface.Device): # pylint: disable=abstract-method
def connect(self): # pylint: disable=no-self-use
return mock.Mock()
def pubkey(self, identity, ecdh=False): # pylint: disable=unused-argument
assert self.conn
return PUBKEY
def sign(self, identity, blob):
"""Sign given blob and return the signature (as bytes)."""
assert self.conn
assert blob == BLOB
return SIG
BLOB = (b'\x00\x00\x00 \xce\xe0\xc9\xd5\xceu/\xe8\xc5\xf2\xbfR+x\xa1\xcf\xb0'
b'\x8e;R\xd3)m\x96\x1b\xb4\xd8s\xf1\x99\x16\xaa2\x00\x00\x00\x05roman'
b'\x00\x00\x00\x0essh-connection\x00\x00\x00\tpublickey'
b'\x01\x00\x00\x00\x13ecdsa-sha2-nistp256\x00\x00\x00h\x00\x00\x00'
b'\x13ecdsa-sha2-nistp256\x00\x00\x00\x08nistp256\x00\x00\x00A'
b'\x04\xd8(\xb5\xa6`\xbet0\x95\xac:[;]\xdc,\xbd\xdc?\xd7\xc0\xec'
b'\xdd\xbc+\xfar~\x9dAis4\xc1\x10yeT~\x1b\xeb\x1aX\xd1\xd9\x9f\xc21'
b'\x13\x8dc\xa7\xa3\x07\xefO\x9e\x95\x0e>\xec\xd8\xaa/')
SIG = (b'R\x19T\xf2\x84$\xef#\x0e\xee\x04X\xc6\xc3\x99T`\xd1\xd8\xf7!'
b'\x862@cx\xb8\xb9i@1\x1b3#\x938\x86]\x97*Y\xb2\x02Xa\xdf@\xecK'
b'\xdc\xf0H\xab\xa8\xac\xa7? \x8f=C\x88N\xe2')
def test_ssh_agent():
identity = device.interface.Identity(identity_str='localhost:22',
curve_name=CURVE)
c = client.Client(device=MockDevice())
assert c.export_public_keys([identity]) == [PUBKEY_TEXT]
signature = c.sign_ssh_challenge(blob=BLOB, identity=identity)
key = formats.import_public_key(PUBKEY_TEXT)
serialized_sig = key['verifier'](sig=signature, msg=BLOB)
stream = io.BytesIO(serialized_sig)
r = util.read_frame(stream)
s = util.read_frame(stream)
assert not stream.read()
assert r[:1] == b'\x00'
assert s[:1] == b'\x00'
assert r[1:] + s[1:] == SIG
# pylint: disable=unused-argument
def cancel_sign(identity, blob):
raise IOError(42, 'ERROR')
c.device.sign = cancel_sign
with pytest.raises(IOError):
c.sign_ssh_challenge(blob=BLOB, identity=identity)

View File

@@ -0,0 +1,103 @@
import binascii
import pytest
from .. import formats
def test_fingerprint():
fp = '5d:41:40:2a:bc:4b:2a:76:b9:71:9d:91:10:17:c5:92'
assert formats.fingerprint(b'hello') == fp
_point = (
44423495295951059636974944244307637263954375053872017334547086177777411863925, # nopep8
111713194882028655451852320740440245619792555065469028846314891587105736340201 # nopep8
)
_public_key = (
'ecdsa-sha2-nistp256 '
'AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzdHAyNTY'
'AAABBBGI2zqveJSB+geQEWG46OvGs2h3+0qu7tIdsH8Wylr'
'V19vttd7GR5rKvTWJt8b9ErthmnFALelAFKOB/u50jsuk= '
'home\n'
)
def test_parse_public_key():
key = formats.import_public_key(_public_key)
assert key['name'] == b'home'
assert key['point'] == _point
assert key['curve'] == 'nist256p1'
assert key['fingerprint'] == '4b:19:bc:0f:c8:7e:dc:fa:1a:e3:c2:ff:6f:e0:80:a2' # nopep8
assert key['type'] == b'ecdsa-sha2-nistp256'
def test_decompress():
blob = '036236ceabde25207e81e404586e3a3af1acda1dfed2abbbb4876c1fc5b296b575'
vk = formats.decompress_pubkey(binascii.unhexlify(blob),
curve_name=formats.CURVE_NIST256)
assert formats.export_public_key(vk, label='home') == _public_key
def test_parse_ed25519():
pubkey = ('ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIFBdF2tj'
'fSO8nLIi736is+f0erq28RTc7CkM11NZtTKR hello\n')
p = formats.import_public_key(pubkey)
assert p['name'] == b'hello'
assert p['curve'] == 'ed25519'
BLOB = (b'\x00\x00\x00\x0bssh-ed25519\x00\x00\x00 P]\x17kc}#'
b'\xbc\x9c\xb2"\xef~\xa2\xb3\xe7\xf4z\xba\xb6\xf1\x14'
b'\xdc\xec)\x0c\xd7SY\xb52\x91')
assert p['blob'] == BLOB
assert p['fingerprint'] == '6b:b0:77:af:e5:3a:21:6d:17:82:9b:06:19:03:a1:97' # nopep8
assert p['type'] == b'ssh-ed25519'
def test_export_ed25519():
pub = (b'\x00P]\x17kc}#\xbc\x9c\xb2"\xef~\xa2\xb3\xe7\xf4'
b'z\xba\xb6\xf1\x14\xdc\xec)\x0c\xd7SY\xb52\x91')
vk = formats.decompress_pubkey(pub, formats.CURVE_ED25519)
result = formats.serialize_verifying_key(vk)
assert result == (b'ssh-ed25519',
b'\x00\x00\x00\x0bssh-ed25519\x00\x00\x00 P]\x17kc}#\xbc'
b'\x9c\xb2"\xef~\xa2\xb3\xe7\xf4z\xba\xb6\xf1\x14\xdc'
b'\xec)\x0c\xd7SY\xb52\x91')
def test_decompress_error():
with pytest.raises(ValueError):
formats.decompress_pubkey('', formats.CURVE_NIST256)
def test_curve_mismatch():
# NIST256 public key
blob = '036236ceabde25207e81e404586e3a3af1acda1dfed2abbbb4876c1fc5b296b575'
with pytest.raises(ValueError):
formats.decompress_pubkey(binascii.unhexlify(blob),
curve_name=formats.CURVE_ED25519)
blob = '00' * 33 # Dummy public key
with pytest.raises(ValueError):
formats.decompress_pubkey(binascii.unhexlify(blob),
curve_name=formats.CURVE_NIST256)
blob = 'FF' * 33 # Unsupported prefix byte
with pytest.raises(ValueError):
formats.decompress_pubkey(binascii.unhexlify(blob),
curve_name=formats.CURVE_NIST256)
def test_serialize_error():
with pytest.raises(TypeError):
formats.serialize_verifying_key(None)
def test_get_ecdh_curve_name():
for c in [formats.CURVE_NIST256, formats.ECDH_CURVE25519]:
assert c == formats.get_ecdh_curve_name(c)
assert (formats.ECDH_CURVE25519 ==
formats.get_ecdh_curve_name(formats.CURVE_ED25519))

View File

@@ -0,0 +1,109 @@
import mock
import pytest
from .. import device, formats, protocol
# pylint: disable=line-too-long
NIST256_KEY = 'ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzdHAyNTYAAABBBEUksojS/qRlTKBKLQO7CBX7a7oqFkysuFn1nJ6gzlR3wNuQXEgd7qb2bjmiiBHsjNxyWvH5SxVi3+fghrqODWo= ssh://localhost' # nopep8
NIST256_BLOB = b'\x00\x00\x00 !S^\xe7\xf8\x1cKN\xde\xcbo\x0c\x83\x9e\xc48\r\xac\xeb,]"\xc1\x9bA\x0eit\xc1\x81\xd4E2\x00\x00\x00\x05roman\x00\x00\x00\x0essh-connection\x00\x00\x00\tpublickey\x01\x00\x00\x00\x13ecdsa-sha2-nistp256\x00\x00\x00h\x00\x00\x00\x13ecdsa-sha2-nistp256\x00\x00\x00\x08nistp256\x00\x00\x00A\x04E$\xb2\x88\xd2\xfe\xa4eL\xa0J-\x03\xbb\x08\x15\xfbk\xba*\x16L\xac\xb8Y\xf5\x9c\x9e\xa0\xceTw\xc0\xdb\x90\\H\x1d\xee\xa6\xf6n9\xa2\x88\x11\xec\x8c\xdcrZ\xf1\xf9K\x15b\xdf\xe7\xe0\x86\xba\x8e\rj' # nopep8
NIST256_SIG = b'\x88G!\x0c\n\x16:\xbeF\xbe\xb9\xd2\xa9&e\x89\xad\xc4}\x10\xf8\xbc\xdc\xef\x0e\x8d_\x8a6.\xb6\x1fq\xf0\x16>,\x9a\xde\xe7(\xd6\xd7\x93\x1f\xed\xf9\x94ddw\xfe\xbdq\x13\xbb\xfc\xa9K\xea\x9dC\xa1\xe9' # nopep8
LIST_MSG = b'\x0b'
LIST_NIST256_REPLY = b'\x00\x00\x00\x84\x0c\x00\x00\x00\x01\x00\x00\x00h\x00\x00\x00\x13ecdsa-sha2-nistp256\x00\x00\x00\x08nistp256\x00\x00\x00A\x04E$\xb2\x88\xd2\xfe\xa4eL\xa0J-\x03\xbb\x08\x15\xfbk\xba*\x16L\xac\xb8Y\xf5\x9c\x9e\xa0\xceTw\xc0\xdb\x90\\H\x1d\xee\xa6\xf6n9\xa2\x88\x11\xec\x8c\xdcrZ\xf1\xf9K\x15b\xdf\xe7\xe0\x86\xba\x8e\rj\x00\x00\x00\x0fssh://localhost' # nopep8
NIST256_SIGN_MSG = b'\r\x00\x00\x00h\x00\x00\x00\x13ecdsa-sha2-nistp256\x00\x00\x00\x08nistp256\x00\x00\x00A\x04E$\xb2\x88\xd2\xfe\xa4eL\xa0J-\x03\xbb\x08\x15\xfbk\xba*\x16L\xac\xb8Y\xf5\x9c\x9e\xa0\xceTw\xc0\xdb\x90\\H\x1d\xee\xa6\xf6n9\xa2\x88\x11\xec\x8c\xdcrZ\xf1\xf9K\x15b\xdf\xe7\xe0\x86\xba\x8e\rj\x00\x00\x00\xd1\x00\x00\x00 !S^\xe7\xf8\x1cKN\xde\xcbo\x0c\x83\x9e\xc48\r\xac\xeb,]"\xc1\x9bA\x0eit\xc1\x81\xd4E2\x00\x00\x00\x05roman\x00\x00\x00\x0essh-connection\x00\x00\x00\tpublickey\x01\x00\x00\x00\x13ecdsa-sha2-nistp256\x00\x00\x00h\x00\x00\x00\x13ecdsa-sha2-nistp256\x00\x00\x00\x08nistp256\x00\x00\x00A\x04E$\xb2\x88\xd2\xfe\xa4eL\xa0J-\x03\xbb\x08\x15\xfbk\xba*\x16L\xac\xb8Y\xf5\x9c\x9e\xa0\xceTw\xc0\xdb\x90\\H\x1d\xee\xa6\xf6n9\xa2\x88\x11\xec\x8c\xdcrZ\xf1\xf9K\x15b\xdf\xe7\xe0\x86\xba\x8e\rj\x00\x00\x00\x00' # nopep8
NIST256_SIGN_REPLY = b'\x00\x00\x00j\x0e\x00\x00\x00e\x00\x00\x00\x13ecdsa-sha2-nistp256\x00\x00\x00J\x00\x00\x00!\x00\x88G!\x0c\n\x16:\xbeF\xbe\xb9\xd2\xa9&e\x89\xad\xc4}\x10\xf8\xbc\xdc\xef\x0e\x8d_\x8a6.\xb6\x1f\x00\x00\x00!\x00q\xf0\x16>,\x9a\xde\xe7(\xd6\xd7\x93\x1f\xed\xf9\x94ddw\xfe\xbdq\x13\xbb\xfc\xa9K\xea\x9dC\xa1\xe9' # nopep8
def fake_connection(keys, signer):
c = mock.Mock()
c.parse_public_keys.return_value = keys
c.sign = signer
return c
def test_list():
key = formats.import_public_key(NIST256_KEY)
key['identity'] = device.interface.Identity('ssh://localhost', 'nist256p1')
h = protocol.Handler(fake_connection(keys=[key], signer=None))
reply = h.handle(LIST_MSG)
assert reply == LIST_NIST256_REPLY
def test_list_legacy_pubs_with_suffix():
h = protocol.Handler(fake_connection(keys=[], signer=None))
suffix = b'\x00\x00\x00\x06foobar'
reply = h.handle(b'\x01' + suffix)
assert reply == b'\x00\x00\x00\x05\x02\x00\x00\x00\x00' # no legacy keys
def test_unsupported():
h = protocol.Handler(fake_connection(keys=[], signer=None))
reply = h.handle(b'\x09')
assert reply == b'\x00\x00\x00\x01\x05'
def ecdsa_signer(identity, blob):
assert str(identity) == '<ssh://localhost|nist256p1>'
assert blob == NIST256_BLOB
return NIST256_SIG
def test_ecdsa_sign():
key = formats.import_public_key(NIST256_KEY)
key['identity'] = device.interface.Identity('ssh://localhost', 'nist256p1')
h = protocol.Handler(fake_connection(keys=[key], signer=ecdsa_signer))
reply = h.handle(NIST256_SIGN_MSG)
assert reply == NIST256_SIGN_REPLY
def test_sign_missing():
h = protocol.Handler(fake_connection(keys=[], signer=ecdsa_signer))
with pytest.raises(KeyError):
h.handle(NIST256_SIGN_MSG)
def test_sign_wrong():
def wrong_signature(identity, blob):
assert str(identity) == '<ssh://localhost|nist256p1>'
assert blob == NIST256_BLOB
return b'\x00' * 64
key = formats.import_public_key(NIST256_KEY)
key['identity'] = device.interface.Identity('ssh://localhost', 'nist256p1')
h = protocol.Handler(fake_connection(keys=[key], signer=wrong_signature))
with pytest.raises(ValueError):
h.handle(NIST256_SIGN_MSG)
def test_sign_cancel():
def cancel_signature(identity, blob): # pylint: disable=unused-argument
raise IOError()
key = formats.import_public_key(NIST256_KEY)
key['identity'] = device.interface.Identity('ssh://localhost', 'nist256p1')
h = protocol.Handler(fake_connection(keys=[key], signer=cancel_signature))
assert h.handle(NIST256_SIGN_MSG) == protocol.failure()
ED25519_KEY = 'ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIFBdF2tjfSO8nLIi736is+f0erq28RTc7CkM11NZtTKR ssh://localhost' # nopep8
ED25519_SIGN_MSG = b'''\r\x00\x00\x003\x00\x00\x00\x0bssh-ed25519\x00\x00\x00 P]\x17kc}#\xbc\x9c\xb2"\xef~\xa2\xb3\xe7\xf4z\xba\xb6\xf1\x14\xdc\xec)\x0c\xd7SY\xb52\x91\x00\x00\x00\x94\x00\x00\x00 i3\xae}yk\\\xa1L\xb9\xe1\xbf\xbc\x8e\x87\r\x0e\xc0\x9f\x97\x0fTC!\x80\x07\x91\xdb^8\xc1\xd62\x00\x00\x00\x05roman\x00\x00\x00\x0essh-connection\x00\x00\x00\tpublickey\x01\x00\x00\x00\x0bssh-ed25519\x00\x00\x003\x00\x00\x00\x0bssh-ed25519\x00\x00\x00 P]\x17kc}#\xbc\x9c\xb2"\xef~\xa2\xb3\xe7\xf4z\xba\xb6\xf1\x14\xdc\xec)\x0c\xd7SY\xb52\x91\x00\x00\x00\x00''' # nopep8
ED25519_SIGN_REPLY = b'''\x00\x00\x00X\x0e\x00\x00\x00S\x00\x00\x00\x0bssh-ed25519\x00\x00\x00@\x8eb)\xa6\xe9P\x83VE\xfbq\xc6\xbf\x1dV3\xe3<O\x11\xc0\xfa\xe4\xed\xb8\x81.\x81\xc8\xa6\xba\x10RA'a\xbc\xa9\xd3\xdb\x98\x07\xf0\x1a\x9c4\x84<\xaf\x99\xb7\xe5G\xeb\xf7$\xc1\r\x86f\x16\x8e\x08\x05''' # nopep8
ED25519_BLOB = b'''\x00\x00\x00 i3\xae}yk\\\xa1L\xb9\xe1\xbf\xbc\x8e\x87\r\x0e\xc0\x9f\x97\x0fTC!\x80\x07\x91\xdb^8\xc1\xd62\x00\x00\x00\x05roman\x00\x00\x00\x0essh-connection\x00\x00\x00\tpublickey\x01\x00\x00\x00\x0bssh-ed25519\x00\x00\x003\x00\x00\x00\x0bssh-ed25519\x00\x00\x00 P]\x17kc}#\xbc\x9c\xb2"\xef~\xa2\xb3\xe7\xf4z\xba\xb6\xf1\x14\xdc\xec)\x0c\xd7SY\xb52\x91''' # nopep8
ED25519_SIG = b'''\x8eb)\xa6\xe9P\x83VE\xfbq\xc6\xbf\x1dV3\xe3<O\x11\xc0\xfa\xe4\xed\xb8\x81.\x81\xc8\xa6\xba\x10RA'a\xbc\xa9\xd3\xdb\x98\x07\xf0\x1a\x9c4\x84<\xaf\x99\xb7\xe5G\xeb\xf7$\xc1\r\x86f\x16\x8e\x08\x05''' # nopep8
def ed25519_signer(identity, blob):
assert str(identity) == '<ssh://localhost|ed25519>'
assert blob == ED25519_BLOB
return ED25519_SIG
def test_ed25519_sign():
key = formats.import_public_key(ED25519_KEY)
key['identity'] = device.interface.Identity('ssh://localhost', 'ed25519')
h = protocol.Handler(fake_connection(keys=[key], signer=ed25519_signer))
reply = h.handle(ED25519_SIGN_MSG)
assert reply == ED25519_SIGN_REPLY

View File

@@ -0,0 +1,140 @@
import functools
import io
import os
import socket
import tempfile
import threading
import mock
import pytest
from .. import protocol, server, util
def test_socket():
path = tempfile.mktemp()
with server.unix_domain_socket_server(path):
pass
assert not os.path.isfile(path)
class FakeSocket(object):
def __init__(self, data=b''):
self.rx = io.BytesIO(data)
self.tx = io.BytesIO()
def sendall(self, data):
self.tx.write(data)
def recv(self, size):
return self.rx.read(size)
def close(self):
pass
def settimeout(self, value):
pass
def empty_device():
c = mock.Mock(spec=['parse_public_keys'])
c.parse_public_keys.return_value = []
return c
def test_handle():
mutex = threading.Lock()
handler = protocol.Handler(conn=empty_device())
conn = FakeSocket()
server.handle_connection(conn, handler, mutex)
msg = bytearray([protocol.msg_code('SSH_AGENTC_REQUEST_RSA_IDENTITIES')])
conn = FakeSocket(util.frame(msg))
server.handle_connection(conn, handler, mutex)
assert conn.tx.getvalue() == b'\x00\x00\x00\x05\x02\x00\x00\x00\x00'
msg = bytearray([protocol.msg_code('SSH2_AGENTC_REQUEST_IDENTITIES')])
conn = FakeSocket(util.frame(msg))
server.handle_connection(conn, handler, mutex)
assert conn.tx.getvalue() == b'\x00\x00\x00\x05\x0C\x00\x00\x00\x00'
msg = bytearray([protocol.msg_code('SSH2_AGENTC_ADD_IDENTITY')])
conn = FakeSocket(util.frame(msg))
server.handle_connection(conn, handler, mutex)
conn.tx.seek(0)
reply = util.read_frame(conn.tx)
assert reply == util.pack('B', protocol.msg_code('SSH_AGENT_FAILURE'))
conn_mock = mock.Mock(spec=FakeSocket)
conn_mock.recv.side_effect = [Exception, EOFError]
server.handle_connection(conn=conn_mock, handler=None, mutex=mutex)
def test_server_thread():
connections = [FakeSocket()]
quit_event = threading.Event()
class FakeServer(object):
def accept(self): # pylint: disable=no-self-use
if connections:
return connections.pop(), 'address'
quit_event.set()
raise socket.timeout()
def getsockname(self): # pylint: disable=no-self-use
return 'fake_server'
handler = protocol.Handler(conn=empty_device()),
handle_conn = functools.partial(server.handle_connection,
handler=handler,
mutex=None)
server.server_thread(sock=FakeServer(),
handle_conn=handle_conn,
quit_event=quit_event)
def test_spawn():
obj = []
def thread(x):
obj.append(x)
with server.spawn(thread, dict(x=1)):
pass
assert obj == [1]
def test_run():
assert server.run_process(['true'], environ={}) == 0
assert server.run_process(['false'], environ={}) == 1
assert server.run_process(command=['bash', '-c', 'exit $X'],
environ={'X': '42'}) == 42
with pytest.raises(OSError):
server.run_process([''], environ={})
def test_serve_main():
handler = protocol.Handler(conn=empty_device())
with server.serve(handler=handler, sock_path=None):
pass
def test_remove():
path = 'foo.bar'
def remove(p):
assert p == path
server.remove_file(path, remove=remove)
def remove_raise(_):
raise OSError('boom')
server.remove_file(path, remove=remove_raise, exists=lambda _: False)
with pytest.raises(OSError):
server.remove_file(path, remove=remove_raise, exists=lambda _: True)

117
libagent/tests/test_util.py Normal file
View File

@@ -0,0 +1,117 @@
import io
import mock
import pytest
from .. import util
def test_bytes2num():
assert util.bytes2num(b'\x12\x34') == 0x1234
def test_num2bytes():
assert util.num2bytes(0x1234, size=2) == b'\x12\x34'
def test_pack():
assert util.pack('BHL', 1, 2, 3) == b'\x01\x00\x02\x00\x00\x00\x03'
def test_frames():
msgs = [b'aaa', b'bb', b'c' * 0x12340]
f = util.frame(*msgs)
assert f == b'\x00\x01\x23\x45' + b''.join(msgs)
assert util.read_frame(io.BytesIO(f)) == b''.join(msgs)
class FakeSocket(object):
def __init__(self):
self.buf = io.BytesIO()
def sendall(self, data):
self.buf.write(data)
def recv(self, size):
return self.buf.read(size)
def test_send_recv():
s = FakeSocket()
util.send(s, b'123')
util.send(s, b'*')
assert s.buf.getvalue() == b'123*'
s.buf.seek(0)
assert util.recv(s, 2) == b'12'
assert util.recv(s, 2) == b'3*'
pytest.raises(EOFError, util.recv, s, 1)
def test_crc24():
assert util.crc24(b'') == b'\xb7\x04\xce'
assert util.crc24(b'1234567890') == b'\x8c\x00\x72'
def test_bit():
assert util.bit(6, 3) == 0
assert util.bit(6, 2) == 1
assert util.bit(6, 1) == 1
assert util.bit(6, 0) == 0
def test_split_bits():
assert util.split_bits(0x1234, 4, 8, 4) == [0x1, 0x23, 0x4]
def test_hexlify():
assert util.hexlify(b'\x12\x34\xab\xcd') == '1234ABCD'
def test_low_bits():
assert util.low_bits(0x1234, 12) == 0x234
assert util.low_bits(0x1234, 32) == 0x1234
assert util.low_bits(0x1234, 0) == 0
def test_readfmt():
stream = io.BytesIO(b'ABC\x12\x34')
assert util.readfmt(stream, 'B') == (65,)
assert util.readfmt(stream, '>2sH') == (b'BC', 0x1234)
def test_prefix_len():
assert util.prefix_len('>H', b'ABCD') == b'\x00\x04ABCD'
def test_reader():
stream = io.BytesIO(b'ABC\x12\x34')
r = util.Reader(stream)
assert r.read(1) == b'A'
assert r.readfmt('2s') == b'BC'
dst = io.BytesIO()
with r.capture(dst):
assert r.readfmt('>H') == 0x1234
assert dst.getvalue() == b'\x12\x34'
with pytest.raises(EOFError):
r.read(1)
def test_setup_logging():
util.setup_logging(verbosity=10)
def test_memoize():
f = mock.Mock(side_effect=lambda x: x)
def func(x):
# mock.Mock doesn't work with functools.wraps()
return f(x)
g = util.memoize(func)
assert g(1) == g(1)
assert g(1) != g(2)
assert f.mock_calls == [mock.call(1), mock.call(2)]

206
libagent/util.py Normal file
View File

@@ -0,0 +1,206 @@
"""Various I/O and serialization utilities."""
import binascii
import contextlib
import functools
import io
import logging
import struct
log = logging.getLogger(__name__)
def send(conn, data):
"""Send data blob to connection socket."""
conn.sendall(data)
def recv(conn, size):
"""
Receive bytes from connection socket or stream.
If size is struct.calcsize()-compatible format, use it to unpack the data.
Otherwise, return the plain blob as bytes.
"""
try:
fmt = size
size = struct.calcsize(fmt)
except TypeError:
fmt = None
try:
_read = conn.recv
except AttributeError:
_read = conn.read
res = io.BytesIO()
while size > 0:
buf = _read(size)
if not buf:
raise EOFError
size = size - len(buf)
res.write(buf)
res = res.getvalue()
if fmt:
return struct.unpack(fmt, res)
else:
return res
def read_frame(conn):
"""Read size-prefixed frame from connection."""
size, = recv(conn, '>L')
return recv(conn, size)
def bytes2num(s):
"""Convert MSB-first bytes to an unsigned integer."""
res = 0
for i, c in enumerate(reversed(bytearray(s))):
res += c << (i * 8)
return res
def num2bytes(value, size):
"""Convert an unsigned integer to MSB-first bytes with specified size."""
res = []
for _ in range(size):
res.append(value & 0xFF)
value = value >> 8
assert value == 0
return bytes(bytearray(list(reversed(res))))
def pack(fmt, *args):
"""Serialize MSB-first message."""
return struct.pack('>' + fmt, *args)
def frame(*msgs):
"""Serialize MSB-first length-prefixed frame."""
res = io.BytesIO()
for msg in msgs:
res.write(msg)
msg = res.getvalue()
return pack('L', len(msg)) + msg
def crc24(blob):
"""See https://tools.ietf.org/html/rfc4880#section-6.1 for details."""
CRC24_INIT = 0x0B704CE
CRC24_POLY = 0x1864CFB
crc = CRC24_INIT
for octet in bytearray(blob):
crc ^= (octet << 16)
for _ in range(8):
crc <<= 1
if crc & 0x1000000:
crc ^= CRC24_POLY
assert 0 <= crc < 0x1000000
crc_bytes = struct.pack('>L', crc)
assert crc_bytes[:1] == b'\x00'
return crc_bytes[1:]
def bit(value, i):
"""Extract the i-th bit out of value."""
return 1 if value & (1 << i) else 0
def low_bits(value, n):
"""Extract the lowest n bits out of value."""
return value & ((1 << n) - 1)
def split_bits(value, *bits):
"""
Split integer value into list of ints, according to `bits` list.
For example, split_bits(0x1234, 4, 8, 4) == [0x1, 0x23, 0x4]
"""
result = []
for b in reversed(bits):
mask = (1 << b) - 1
result.append(value & mask)
value = value >> b
assert value == 0
result.reverse()
return result
def readfmt(stream, fmt):
"""Read and unpack an object from stream, using a struct format string."""
size = struct.calcsize(fmt)
blob = stream.read(size)
return struct.unpack(fmt, blob)
def prefix_len(fmt, blob):
"""Prefix `blob` with its size, serialized using `fmt` format."""
return struct.pack(fmt, len(blob)) + blob
def hexlify(blob):
"""Utility for consistent hexadecimal formatting."""
return binascii.hexlify(blob).decode('ascii').upper()
class Reader(object):
"""Read basic type objects out of given stream."""
def __init__(self, stream):
"""Create a non-capturing reader."""
self.s = stream
self._captured = None
def readfmt(self, fmt):
"""Read a specified object, using a struct format string."""
size = struct.calcsize(fmt)
blob = self.read(size)
obj, = struct.unpack(fmt, blob)
return obj
def read(self, size=None):
"""Read `size` bytes from stream."""
blob = self.s.read(size)
if size is not None and len(blob) < size:
raise EOFError
if self._captured:
self._captured.write(blob)
return blob
@contextlib.contextmanager
def capture(self, stream):
"""Capture all data read during this context."""
self._captured = stream
try:
yield
finally:
self._captured = None
def setup_logging(verbosity, **kwargs):
"""Configure logging for this tool."""
fmt = ('%(asctime)s %(levelname)-12s %(message)-100s '
'[%(filename)s:%(lineno)d]')
levels = [logging.WARNING, logging.INFO, logging.DEBUG]
level = levels[min(verbosity, len(levels) - 1)]
logging.basicConfig(format=fmt, level=level, **kwargs)
def memoize(func):
"""Simple caching decorator."""
cache = {}
@functools.wraps(func)
def wrapper(*args, **kwargs):
"""Caching wrapper."""
key = (args, tuple(sorted(kwargs.items())))
if key in cache:
return cache[key]
else:
result = func(*args, **kwargs)
cache[key] = result
return result
return wrapper