Split the package into a shared library and separate per-device packages
This commit is contained in:
1
libagent/tests/__init__.py
Normal file
1
libagent/tests/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Unit-tests for this package."""
|
||||
72
libagent/tests/test_client.py
Normal file
72
libagent/tests/test_client.py
Normal file
@@ -0,0 +1,72 @@
|
||||
import io
|
||||
|
||||
import mock
|
||||
import pytest
|
||||
|
||||
from .. import client, device, formats, util
|
||||
|
||||
ADDR = [2147483661, 2810943954, 3938368396, 3454558782, 3848009040]
|
||||
CURVE = 'nist256p1'
|
||||
|
||||
PUBKEY = (b'\x03\xd8(\xb5\xa6`\xbet0\x95\xac:[;]\xdc,\xbd\xdc?\xd7\xc0\xec'
|
||||
b'\xdd\xbc+\xfar~\x9dAis')
|
||||
PUBKEY_TEXT = ('ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzd'
|
||||
'HAyNTYAAABBBNgotaZgvnQwlaw6Wztd3Cy93D/XwOzdvCv6cn6dQWlzNMEQeW'
|
||||
'VUfhvrGljR2Z/CMRONY6ejB+9PnpUOPuzYqi8= <localhost:22|nist256p1>\n')
|
||||
|
||||
|
||||
class MockDevice(device.interface.Device): # pylint: disable=abstract-method
|
||||
|
||||
def connect(self): # pylint: disable=no-self-use
|
||||
return mock.Mock()
|
||||
|
||||
def pubkey(self, identity, ecdh=False): # pylint: disable=unused-argument
|
||||
assert self.conn
|
||||
return PUBKEY
|
||||
|
||||
def sign(self, identity, blob):
|
||||
"""Sign given blob and return the signature (as bytes)."""
|
||||
assert self.conn
|
||||
assert blob == BLOB
|
||||
return SIG
|
||||
|
||||
|
||||
BLOB = (b'\x00\x00\x00 \xce\xe0\xc9\xd5\xceu/\xe8\xc5\xf2\xbfR+x\xa1\xcf\xb0'
|
||||
b'\x8e;R\xd3)m\x96\x1b\xb4\xd8s\xf1\x99\x16\xaa2\x00\x00\x00\x05roman'
|
||||
b'\x00\x00\x00\x0essh-connection\x00\x00\x00\tpublickey'
|
||||
b'\x01\x00\x00\x00\x13ecdsa-sha2-nistp256\x00\x00\x00h\x00\x00\x00'
|
||||
b'\x13ecdsa-sha2-nistp256\x00\x00\x00\x08nistp256\x00\x00\x00A'
|
||||
b'\x04\xd8(\xb5\xa6`\xbet0\x95\xac:[;]\xdc,\xbd\xdc?\xd7\xc0\xec'
|
||||
b'\xdd\xbc+\xfar~\x9dAis4\xc1\x10yeT~\x1b\xeb\x1aX\xd1\xd9\x9f\xc21'
|
||||
b'\x13\x8dc\xa7\xa3\x07\xefO\x9e\x95\x0e>\xec\xd8\xaa/')
|
||||
|
||||
SIG = (b'R\x19T\xf2\x84$\xef#\x0e\xee\x04X\xc6\xc3\x99T`\xd1\xd8\xf7!'
|
||||
b'\x862@cx\xb8\xb9i@1\x1b3#\x938\x86]\x97*Y\xb2\x02Xa\xdf@\xecK'
|
||||
b'\xdc\xf0H\xab\xa8\xac\xa7? \x8f=C\x88N\xe2')
|
||||
|
||||
|
||||
def test_ssh_agent():
|
||||
identity = device.interface.Identity(identity_str='localhost:22',
|
||||
curve_name=CURVE)
|
||||
c = client.Client(device=MockDevice())
|
||||
assert c.export_public_keys([identity]) == [PUBKEY_TEXT]
|
||||
signature = c.sign_ssh_challenge(blob=BLOB, identity=identity)
|
||||
|
||||
key = formats.import_public_key(PUBKEY_TEXT)
|
||||
serialized_sig = key['verifier'](sig=signature, msg=BLOB)
|
||||
|
||||
stream = io.BytesIO(serialized_sig)
|
||||
r = util.read_frame(stream)
|
||||
s = util.read_frame(stream)
|
||||
assert not stream.read()
|
||||
assert r[:1] == b'\x00'
|
||||
assert s[:1] == b'\x00'
|
||||
assert r[1:] + s[1:] == SIG
|
||||
|
||||
# pylint: disable=unused-argument
|
||||
def cancel_sign(identity, blob):
|
||||
raise IOError(42, 'ERROR')
|
||||
|
||||
c.device.sign = cancel_sign
|
||||
with pytest.raises(IOError):
|
||||
c.sign_ssh_challenge(blob=BLOB, identity=identity)
|
||||
103
libagent/tests/test_formats.py
Normal file
103
libagent/tests/test_formats.py
Normal file
@@ -0,0 +1,103 @@
|
||||
import binascii
|
||||
|
||||
import pytest
|
||||
|
||||
from .. import formats
|
||||
|
||||
|
||||
def test_fingerprint():
|
||||
fp = '5d:41:40:2a:bc:4b:2a:76:b9:71:9d:91:10:17:c5:92'
|
||||
assert formats.fingerprint(b'hello') == fp
|
||||
|
||||
|
||||
_point = (
|
||||
44423495295951059636974944244307637263954375053872017334547086177777411863925, # nopep8
|
||||
111713194882028655451852320740440245619792555065469028846314891587105736340201 # nopep8
|
||||
)
|
||||
|
||||
_public_key = (
|
||||
'ecdsa-sha2-nistp256 '
|
||||
'AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzdHAyNTY'
|
||||
'AAABBBGI2zqveJSB+geQEWG46OvGs2h3+0qu7tIdsH8Wylr'
|
||||
'V19vttd7GR5rKvTWJt8b9ErthmnFALelAFKOB/u50jsuk= '
|
||||
'home\n'
|
||||
)
|
||||
|
||||
|
||||
def test_parse_public_key():
|
||||
key = formats.import_public_key(_public_key)
|
||||
assert key['name'] == b'home'
|
||||
assert key['point'] == _point
|
||||
|
||||
assert key['curve'] == 'nist256p1'
|
||||
assert key['fingerprint'] == '4b:19:bc:0f:c8:7e:dc:fa:1a:e3:c2:ff:6f:e0:80:a2' # nopep8
|
||||
assert key['type'] == b'ecdsa-sha2-nistp256'
|
||||
|
||||
|
||||
def test_decompress():
|
||||
blob = '036236ceabde25207e81e404586e3a3af1acda1dfed2abbbb4876c1fc5b296b575'
|
||||
vk = formats.decompress_pubkey(binascii.unhexlify(blob),
|
||||
curve_name=formats.CURVE_NIST256)
|
||||
assert formats.export_public_key(vk, label='home') == _public_key
|
||||
|
||||
|
||||
def test_parse_ed25519():
|
||||
pubkey = ('ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIFBdF2tj'
|
||||
'fSO8nLIi736is+f0erq28RTc7CkM11NZtTKR hello\n')
|
||||
p = formats.import_public_key(pubkey)
|
||||
assert p['name'] == b'hello'
|
||||
assert p['curve'] == 'ed25519'
|
||||
|
||||
BLOB = (b'\x00\x00\x00\x0bssh-ed25519\x00\x00\x00 P]\x17kc}#'
|
||||
b'\xbc\x9c\xb2"\xef~\xa2\xb3\xe7\xf4z\xba\xb6\xf1\x14'
|
||||
b'\xdc\xec)\x0c\xd7SY\xb52\x91')
|
||||
assert p['blob'] == BLOB
|
||||
assert p['fingerprint'] == '6b:b0:77:af:e5:3a:21:6d:17:82:9b:06:19:03:a1:97' # nopep8
|
||||
assert p['type'] == b'ssh-ed25519'
|
||||
|
||||
|
||||
def test_export_ed25519():
|
||||
pub = (b'\x00P]\x17kc}#\xbc\x9c\xb2"\xef~\xa2\xb3\xe7\xf4'
|
||||
b'z\xba\xb6\xf1\x14\xdc\xec)\x0c\xd7SY\xb52\x91')
|
||||
vk = formats.decompress_pubkey(pub, formats.CURVE_ED25519)
|
||||
result = formats.serialize_verifying_key(vk)
|
||||
assert result == (b'ssh-ed25519',
|
||||
b'\x00\x00\x00\x0bssh-ed25519\x00\x00\x00 P]\x17kc}#\xbc'
|
||||
b'\x9c\xb2"\xef~\xa2\xb3\xe7\xf4z\xba\xb6\xf1\x14\xdc'
|
||||
b'\xec)\x0c\xd7SY\xb52\x91')
|
||||
|
||||
|
||||
def test_decompress_error():
|
||||
with pytest.raises(ValueError):
|
||||
formats.decompress_pubkey('', formats.CURVE_NIST256)
|
||||
|
||||
|
||||
def test_curve_mismatch():
|
||||
# NIST256 public key
|
||||
blob = '036236ceabde25207e81e404586e3a3af1acda1dfed2abbbb4876c1fc5b296b575'
|
||||
with pytest.raises(ValueError):
|
||||
formats.decompress_pubkey(binascii.unhexlify(blob),
|
||||
curve_name=formats.CURVE_ED25519)
|
||||
|
||||
blob = '00' * 33 # Dummy public key
|
||||
with pytest.raises(ValueError):
|
||||
formats.decompress_pubkey(binascii.unhexlify(blob),
|
||||
curve_name=formats.CURVE_NIST256)
|
||||
|
||||
blob = 'FF' * 33 # Unsupported prefix byte
|
||||
with pytest.raises(ValueError):
|
||||
formats.decompress_pubkey(binascii.unhexlify(blob),
|
||||
curve_name=formats.CURVE_NIST256)
|
||||
|
||||
|
||||
def test_serialize_error():
|
||||
with pytest.raises(TypeError):
|
||||
formats.serialize_verifying_key(None)
|
||||
|
||||
|
||||
def test_get_ecdh_curve_name():
|
||||
for c in [formats.CURVE_NIST256, formats.ECDH_CURVE25519]:
|
||||
assert c == formats.get_ecdh_curve_name(c)
|
||||
|
||||
assert (formats.ECDH_CURVE25519 ==
|
||||
formats.get_ecdh_curve_name(formats.CURVE_ED25519))
|
||||
109
libagent/tests/test_protocol.py
Normal file
109
libagent/tests/test_protocol.py
Normal file
@@ -0,0 +1,109 @@
|
||||
import mock
|
||||
import pytest
|
||||
|
||||
from .. import device, formats, protocol
|
||||
|
||||
# pylint: disable=line-too-long
|
||||
|
||||
NIST256_KEY = 'ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzdHAyNTYAAABBBEUksojS/qRlTKBKLQO7CBX7a7oqFkysuFn1nJ6gzlR3wNuQXEgd7qb2bjmiiBHsjNxyWvH5SxVi3+fghrqODWo= ssh://localhost' # nopep8
|
||||
NIST256_BLOB = b'\x00\x00\x00 !S^\xe7\xf8\x1cKN\xde\xcbo\x0c\x83\x9e\xc48\r\xac\xeb,]"\xc1\x9bA\x0eit\xc1\x81\xd4E2\x00\x00\x00\x05roman\x00\x00\x00\x0essh-connection\x00\x00\x00\tpublickey\x01\x00\x00\x00\x13ecdsa-sha2-nistp256\x00\x00\x00h\x00\x00\x00\x13ecdsa-sha2-nistp256\x00\x00\x00\x08nistp256\x00\x00\x00A\x04E$\xb2\x88\xd2\xfe\xa4eL\xa0J-\x03\xbb\x08\x15\xfbk\xba*\x16L\xac\xb8Y\xf5\x9c\x9e\xa0\xceTw\xc0\xdb\x90\\H\x1d\xee\xa6\xf6n9\xa2\x88\x11\xec\x8c\xdcrZ\xf1\xf9K\x15b\xdf\xe7\xe0\x86\xba\x8e\rj' # nopep8
|
||||
NIST256_SIG = b'\x88G!\x0c\n\x16:\xbeF\xbe\xb9\xd2\xa9&e\x89\xad\xc4}\x10\xf8\xbc\xdc\xef\x0e\x8d_\x8a6.\xb6\x1fq\xf0\x16>,\x9a\xde\xe7(\xd6\xd7\x93\x1f\xed\xf9\x94ddw\xfe\xbdq\x13\xbb\xfc\xa9K\xea\x9dC\xa1\xe9' # nopep8
|
||||
|
||||
LIST_MSG = b'\x0b'
|
||||
LIST_NIST256_REPLY = b'\x00\x00\x00\x84\x0c\x00\x00\x00\x01\x00\x00\x00h\x00\x00\x00\x13ecdsa-sha2-nistp256\x00\x00\x00\x08nistp256\x00\x00\x00A\x04E$\xb2\x88\xd2\xfe\xa4eL\xa0J-\x03\xbb\x08\x15\xfbk\xba*\x16L\xac\xb8Y\xf5\x9c\x9e\xa0\xceTw\xc0\xdb\x90\\H\x1d\xee\xa6\xf6n9\xa2\x88\x11\xec\x8c\xdcrZ\xf1\xf9K\x15b\xdf\xe7\xe0\x86\xba\x8e\rj\x00\x00\x00\x0fssh://localhost' # nopep8
|
||||
|
||||
NIST256_SIGN_MSG = b'\r\x00\x00\x00h\x00\x00\x00\x13ecdsa-sha2-nistp256\x00\x00\x00\x08nistp256\x00\x00\x00A\x04E$\xb2\x88\xd2\xfe\xa4eL\xa0J-\x03\xbb\x08\x15\xfbk\xba*\x16L\xac\xb8Y\xf5\x9c\x9e\xa0\xceTw\xc0\xdb\x90\\H\x1d\xee\xa6\xf6n9\xa2\x88\x11\xec\x8c\xdcrZ\xf1\xf9K\x15b\xdf\xe7\xe0\x86\xba\x8e\rj\x00\x00\x00\xd1\x00\x00\x00 !S^\xe7\xf8\x1cKN\xde\xcbo\x0c\x83\x9e\xc48\r\xac\xeb,]"\xc1\x9bA\x0eit\xc1\x81\xd4E2\x00\x00\x00\x05roman\x00\x00\x00\x0essh-connection\x00\x00\x00\tpublickey\x01\x00\x00\x00\x13ecdsa-sha2-nistp256\x00\x00\x00h\x00\x00\x00\x13ecdsa-sha2-nistp256\x00\x00\x00\x08nistp256\x00\x00\x00A\x04E$\xb2\x88\xd2\xfe\xa4eL\xa0J-\x03\xbb\x08\x15\xfbk\xba*\x16L\xac\xb8Y\xf5\x9c\x9e\xa0\xceTw\xc0\xdb\x90\\H\x1d\xee\xa6\xf6n9\xa2\x88\x11\xec\x8c\xdcrZ\xf1\xf9K\x15b\xdf\xe7\xe0\x86\xba\x8e\rj\x00\x00\x00\x00' # nopep8
|
||||
NIST256_SIGN_REPLY = b'\x00\x00\x00j\x0e\x00\x00\x00e\x00\x00\x00\x13ecdsa-sha2-nistp256\x00\x00\x00J\x00\x00\x00!\x00\x88G!\x0c\n\x16:\xbeF\xbe\xb9\xd2\xa9&e\x89\xad\xc4}\x10\xf8\xbc\xdc\xef\x0e\x8d_\x8a6.\xb6\x1f\x00\x00\x00!\x00q\xf0\x16>,\x9a\xde\xe7(\xd6\xd7\x93\x1f\xed\xf9\x94ddw\xfe\xbdq\x13\xbb\xfc\xa9K\xea\x9dC\xa1\xe9' # nopep8
|
||||
|
||||
|
||||
def fake_connection(keys, signer):
|
||||
c = mock.Mock()
|
||||
c.parse_public_keys.return_value = keys
|
||||
c.sign = signer
|
||||
return c
|
||||
|
||||
|
||||
def test_list():
|
||||
key = formats.import_public_key(NIST256_KEY)
|
||||
key['identity'] = device.interface.Identity('ssh://localhost', 'nist256p1')
|
||||
h = protocol.Handler(fake_connection(keys=[key], signer=None))
|
||||
reply = h.handle(LIST_MSG)
|
||||
assert reply == LIST_NIST256_REPLY
|
||||
|
||||
|
||||
def test_list_legacy_pubs_with_suffix():
|
||||
h = protocol.Handler(fake_connection(keys=[], signer=None))
|
||||
suffix = b'\x00\x00\x00\x06foobar'
|
||||
reply = h.handle(b'\x01' + suffix)
|
||||
assert reply == b'\x00\x00\x00\x05\x02\x00\x00\x00\x00' # no legacy keys
|
||||
|
||||
|
||||
def test_unsupported():
|
||||
h = protocol.Handler(fake_connection(keys=[], signer=None))
|
||||
reply = h.handle(b'\x09')
|
||||
assert reply == b'\x00\x00\x00\x01\x05'
|
||||
|
||||
|
||||
def ecdsa_signer(identity, blob):
|
||||
assert str(identity) == '<ssh://localhost|nist256p1>'
|
||||
assert blob == NIST256_BLOB
|
||||
return NIST256_SIG
|
||||
|
||||
|
||||
def test_ecdsa_sign():
|
||||
key = formats.import_public_key(NIST256_KEY)
|
||||
key['identity'] = device.interface.Identity('ssh://localhost', 'nist256p1')
|
||||
h = protocol.Handler(fake_connection(keys=[key], signer=ecdsa_signer))
|
||||
reply = h.handle(NIST256_SIGN_MSG)
|
||||
assert reply == NIST256_SIGN_REPLY
|
||||
|
||||
|
||||
def test_sign_missing():
|
||||
h = protocol.Handler(fake_connection(keys=[], signer=ecdsa_signer))
|
||||
with pytest.raises(KeyError):
|
||||
h.handle(NIST256_SIGN_MSG)
|
||||
|
||||
|
||||
def test_sign_wrong():
|
||||
def wrong_signature(identity, blob):
|
||||
assert str(identity) == '<ssh://localhost|nist256p1>'
|
||||
assert blob == NIST256_BLOB
|
||||
return b'\x00' * 64
|
||||
|
||||
key = formats.import_public_key(NIST256_KEY)
|
||||
key['identity'] = device.interface.Identity('ssh://localhost', 'nist256p1')
|
||||
h = protocol.Handler(fake_connection(keys=[key], signer=wrong_signature))
|
||||
with pytest.raises(ValueError):
|
||||
h.handle(NIST256_SIGN_MSG)
|
||||
|
||||
|
||||
def test_sign_cancel():
|
||||
def cancel_signature(identity, blob): # pylint: disable=unused-argument
|
||||
raise IOError()
|
||||
|
||||
key = formats.import_public_key(NIST256_KEY)
|
||||
key['identity'] = device.interface.Identity('ssh://localhost', 'nist256p1')
|
||||
h = protocol.Handler(fake_connection(keys=[key], signer=cancel_signature))
|
||||
assert h.handle(NIST256_SIGN_MSG) == protocol.failure()
|
||||
|
||||
|
||||
ED25519_KEY = 'ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIFBdF2tjfSO8nLIi736is+f0erq28RTc7CkM11NZtTKR ssh://localhost' # nopep8
|
||||
ED25519_SIGN_MSG = b'''\r\x00\x00\x003\x00\x00\x00\x0bssh-ed25519\x00\x00\x00 P]\x17kc}#\xbc\x9c\xb2"\xef~\xa2\xb3\xe7\xf4z\xba\xb6\xf1\x14\xdc\xec)\x0c\xd7SY\xb52\x91\x00\x00\x00\x94\x00\x00\x00 i3\xae}yk\\\xa1L\xb9\xe1\xbf\xbc\x8e\x87\r\x0e\xc0\x9f\x97\x0fTC!\x80\x07\x91\xdb^8\xc1\xd62\x00\x00\x00\x05roman\x00\x00\x00\x0essh-connection\x00\x00\x00\tpublickey\x01\x00\x00\x00\x0bssh-ed25519\x00\x00\x003\x00\x00\x00\x0bssh-ed25519\x00\x00\x00 P]\x17kc}#\xbc\x9c\xb2"\xef~\xa2\xb3\xe7\xf4z\xba\xb6\xf1\x14\xdc\xec)\x0c\xd7SY\xb52\x91\x00\x00\x00\x00''' # nopep8
|
||||
ED25519_SIGN_REPLY = b'''\x00\x00\x00X\x0e\x00\x00\x00S\x00\x00\x00\x0bssh-ed25519\x00\x00\x00@\x8eb)\xa6\xe9P\x83VE\xfbq\xc6\xbf\x1dV3\xe3<O\x11\xc0\xfa\xe4\xed\xb8\x81.\x81\xc8\xa6\xba\x10RA'a\xbc\xa9\xd3\xdb\x98\x07\xf0\x1a\x9c4\x84<\xaf\x99\xb7\xe5G\xeb\xf7$\xc1\r\x86f\x16\x8e\x08\x05''' # nopep8
|
||||
|
||||
ED25519_BLOB = b'''\x00\x00\x00 i3\xae}yk\\\xa1L\xb9\xe1\xbf\xbc\x8e\x87\r\x0e\xc0\x9f\x97\x0fTC!\x80\x07\x91\xdb^8\xc1\xd62\x00\x00\x00\x05roman\x00\x00\x00\x0essh-connection\x00\x00\x00\tpublickey\x01\x00\x00\x00\x0bssh-ed25519\x00\x00\x003\x00\x00\x00\x0bssh-ed25519\x00\x00\x00 P]\x17kc}#\xbc\x9c\xb2"\xef~\xa2\xb3\xe7\xf4z\xba\xb6\xf1\x14\xdc\xec)\x0c\xd7SY\xb52\x91''' # nopep8
|
||||
ED25519_SIG = b'''\x8eb)\xa6\xe9P\x83VE\xfbq\xc6\xbf\x1dV3\xe3<O\x11\xc0\xfa\xe4\xed\xb8\x81.\x81\xc8\xa6\xba\x10RA'a\xbc\xa9\xd3\xdb\x98\x07\xf0\x1a\x9c4\x84<\xaf\x99\xb7\xe5G\xeb\xf7$\xc1\r\x86f\x16\x8e\x08\x05''' # nopep8
|
||||
|
||||
|
||||
def ed25519_signer(identity, blob):
|
||||
assert str(identity) == '<ssh://localhost|ed25519>'
|
||||
assert blob == ED25519_BLOB
|
||||
return ED25519_SIG
|
||||
|
||||
|
||||
def test_ed25519_sign():
|
||||
key = formats.import_public_key(ED25519_KEY)
|
||||
key['identity'] = device.interface.Identity('ssh://localhost', 'ed25519')
|
||||
h = protocol.Handler(fake_connection(keys=[key], signer=ed25519_signer))
|
||||
reply = h.handle(ED25519_SIGN_MSG)
|
||||
assert reply == ED25519_SIGN_REPLY
|
||||
140
libagent/tests/test_server.py
Normal file
140
libagent/tests/test_server.py
Normal file
@@ -0,0 +1,140 @@
|
||||
import functools
|
||||
import io
|
||||
import os
|
||||
import socket
|
||||
import tempfile
|
||||
import threading
|
||||
|
||||
import mock
|
||||
import pytest
|
||||
|
||||
from .. import protocol, server, util
|
||||
|
||||
|
||||
def test_socket():
|
||||
path = tempfile.mktemp()
|
||||
with server.unix_domain_socket_server(path):
|
||||
pass
|
||||
assert not os.path.isfile(path)
|
||||
|
||||
|
||||
class FakeSocket(object):
|
||||
|
||||
def __init__(self, data=b''):
|
||||
self.rx = io.BytesIO(data)
|
||||
self.tx = io.BytesIO()
|
||||
|
||||
def sendall(self, data):
|
||||
self.tx.write(data)
|
||||
|
||||
def recv(self, size):
|
||||
return self.rx.read(size)
|
||||
|
||||
def close(self):
|
||||
pass
|
||||
|
||||
def settimeout(self, value):
|
||||
pass
|
||||
|
||||
|
||||
def empty_device():
|
||||
c = mock.Mock(spec=['parse_public_keys'])
|
||||
c.parse_public_keys.return_value = []
|
||||
return c
|
||||
|
||||
|
||||
def test_handle():
|
||||
mutex = threading.Lock()
|
||||
|
||||
handler = protocol.Handler(conn=empty_device())
|
||||
conn = FakeSocket()
|
||||
server.handle_connection(conn, handler, mutex)
|
||||
|
||||
msg = bytearray([protocol.msg_code('SSH_AGENTC_REQUEST_RSA_IDENTITIES')])
|
||||
conn = FakeSocket(util.frame(msg))
|
||||
server.handle_connection(conn, handler, mutex)
|
||||
assert conn.tx.getvalue() == b'\x00\x00\x00\x05\x02\x00\x00\x00\x00'
|
||||
|
||||
msg = bytearray([protocol.msg_code('SSH2_AGENTC_REQUEST_IDENTITIES')])
|
||||
conn = FakeSocket(util.frame(msg))
|
||||
server.handle_connection(conn, handler, mutex)
|
||||
assert conn.tx.getvalue() == b'\x00\x00\x00\x05\x0C\x00\x00\x00\x00'
|
||||
|
||||
msg = bytearray([protocol.msg_code('SSH2_AGENTC_ADD_IDENTITY')])
|
||||
conn = FakeSocket(util.frame(msg))
|
||||
server.handle_connection(conn, handler, mutex)
|
||||
conn.tx.seek(0)
|
||||
reply = util.read_frame(conn.tx)
|
||||
assert reply == util.pack('B', protocol.msg_code('SSH_AGENT_FAILURE'))
|
||||
|
||||
conn_mock = mock.Mock(spec=FakeSocket)
|
||||
conn_mock.recv.side_effect = [Exception, EOFError]
|
||||
server.handle_connection(conn=conn_mock, handler=None, mutex=mutex)
|
||||
|
||||
|
||||
def test_server_thread():
|
||||
connections = [FakeSocket()]
|
||||
quit_event = threading.Event()
|
||||
|
||||
class FakeServer(object):
|
||||
def accept(self): # pylint: disable=no-self-use
|
||||
if connections:
|
||||
return connections.pop(), 'address'
|
||||
quit_event.set()
|
||||
raise socket.timeout()
|
||||
|
||||
def getsockname(self): # pylint: disable=no-self-use
|
||||
return 'fake_server'
|
||||
|
||||
handler = protocol.Handler(conn=empty_device()),
|
||||
handle_conn = functools.partial(server.handle_connection,
|
||||
handler=handler,
|
||||
mutex=None)
|
||||
server.server_thread(sock=FakeServer(),
|
||||
handle_conn=handle_conn,
|
||||
quit_event=quit_event)
|
||||
|
||||
|
||||
def test_spawn():
|
||||
obj = []
|
||||
|
||||
def thread(x):
|
||||
obj.append(x)
|
||||
|
||||
with server.spawn(thread, dict(x=1)):
|
||||
pass
|
||||
|
||||
assert obj == [1]
|
||||
|
||||
|
||||
def test_run():
|
||||
assert server.run_process(['true'], environ={}) == 0
|
||||
assert server.run_process(['false'], environ={}) == 1
|
||||
assert server.run_process(command=['bash', '-c', 'exit $X'],
|
||||
environ={'X': '42'}) == 42
|
||||
|
||||
with pytest.raises(OSError):
|
||||
server.run_process([''], environ={})
|
||||
|
||||
|
||||
def test_serve_main():
|
||||
handler = protocol.Handler(conn=empty_device())
|
||||
with server.serve(handler=handler, sock_path=None):
|
||||
pass
|
||||
|
||||
|
||||
def test_remove():
|
||||
path = 'foo.bar'
|
||||
|
||||
def remove(p):
|
||||
assert p == path
|
||||
|
||||
server.remove_file(path, remove=remove)
|
||||
|
||||
def remove_raise(_):
|
||||
raise OSError('boom')
|
||||
|
||||
server.remove_file(path, remove=remove_raise, exists=lambda _: False)
|
||||
|
||||
with pytest.raises(OSError):
|
||||
server.remove_file(path, remove=remove_raise, exists=lambda _: True)
|
||||
117
libagent/tests/test_util.py
Normal file
117
libagent/tests/test_util.py
Normal file
@@ -0,0 +1,117 @@
|
||||
import io
|
||||
|
||||
import mock
|
||||
import pytest
|
||||
|
||||
from .. import util
|
||||
|
||||
|
||||
def test_bytes2num():
|
||||
assert util.bytes2num(b'\x12\x34') == 0x1234
|
||||
|
||||
|
||||
def test_num2bytes():
|
||||
assert util.num2bytes(0x1234, size=2) == b'\x12\x34'
|
||||
|
||||
|
||||
def test_pack():
|
||||
assert util.pack('BHL', 1, 2, 3) == b'\x01\x00\x02\x00\x00\x00\x03'
|
||||
|
||||
|
||||
def test_frames():
|
||||
msgs = [b'aaa', b'bb', b'c' * 0x12340]
|
||||
f = util.frame(*msgs)
|
||||
assert f == b'\x00\x01\x23\x45' + b''.join(msgs)
|
||||
assert util.read_frame(io.BytesIO(f)) == b''.join(msgs)
|
||||
|
||||
|
||||
class FakeSocket(object):
|
||||
def __init__(self):
|
||||
self.buf = io.BytesIO()
|
||||
|
||||
def sendall(self, data):
|
||||
self.buf.write(data)
|
||||
|
||||
def recv(self, size):
|
||||
return self.buf.read(size)
|
||||
|
||||
|
||||
def test_send_recv():
|
||||
s = FakeSocket()
|
||||
util.send(s, b'123')
|
||||
util.send(s, b'*')
|
||||
assert s.buf.getvalue() == b'123*'
|
||||
|
||||
s.buf.seek(0)
|
||||
assert util.recv(s, 2) == b'12'
|
||||
assert util.recv(s, 2) == b'3*'
|
||||
|
||||
pytest.raises(EOFError, util.recv, s, 1)
|
||||
|
||||
|
||||
def test_crc24():
|
||||
assert util.crc24(b'') == b'\xb7\x04\xce'
|
||||
assert util.crc24(b'1234567890') == b'\x8c\x00\x72'
|
||||
|
||||
|
||||
def test_bit():
|
||||
assert util.bit(6, 3) == 0
|
||||
assert util.bit(6, 2) == 1
|
||||
assert util.bit(6, 1) == 1
|
||||
assert util.bit(6, 0) == 0
|
||||
|
||||
|
||||
def test_split_bits():
|
||||
assert util.split_bits(0x1234, 4, 8, 4) == [0x1, 0x23, 0x4]
|
||||
|
||||
|
||||
def test_hexlify():
|
||||
assert util.hexlify(b'\x12\x34\xab\xcd') == '1234ABCD'
|
||||
|
||||
|
||||
def test_low_bits():
|
||||
assert util.low_bits(0x1234, 12) == 0x234
|
||||
assert util.low_bits(0x1234, 32) == 0x1234
|
||||
assert util.low_bits(0x1234, 0) == 0
|
||||
|
||||
|
||||
def test_readfmt():
|
||||
stream = io.BytesIO(b'ABC\x12\x34')
|
||||
assert util.readfmt(stream, 'B') == (65,)
|
||||
assert util.readfmt(stream, '>2sH') == (b'BC', 0x1234)
|
||||
|
||||
|
||||
def test_prefix_len():
|
||||
assert util.prefix_len('>H', b'ABCD') == b'\x00\x04ABCD'
|
||||
|
||||
|
||||
def test_reader():
|
||||
stream = io.BytesIO(b'ABC\x12\x34')
|
||||
r = util.Reader(stream)
|
||||
assert r.read(1) == b'A'
|
||||
assert r.readfmt('2s') == b'BC'
|
||||
|
||||
dst = io.BytesIO()
|
||||
with r.capture(dst):
|
||||
assert r.readfmt('>H') == 0x1234
|
||||
assert dst.getvalue() == b'\x12\x34'
|
||||
|
||||
with pytest.raises(EOFError):
|
||||
r.read(1)
|
||||
|
||||
|
||||
def test_setup_logging():
|
||||
util.setup_logging(verbosity=10)
|
||||
|
||||
|
||||
def test_memoize():
|
||||
f = mock.Mock(side_effect=lambda x: x)
|
||||
|
||||
def func(x):
|
||||
# mock.Mock doesn't work with functools.wraps()
|
||||
return f(x)
|
||||
|
||||
g = util.memoize(func)
|
||||
assert g(1) == g(1)
|
||||
assert g(1) != g(2)
|
||||
assert f.mock_calls == [mock.call(1), mock.call(2)]
|
||||
Reference in New Issue
Block a user