server: improve coverage
This commit is contained in:
@@ -13,14 +13,18 @@ import logging
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def remove_file(path, remove=os.remove, exists=os.path.exists):
|
||||
try:
|
||||
remove(path)
|
||||
except OSError:
|
||||
if exists(path):
|
||||
raise
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def unix_domain_socket_server(sock_path):
|
||||
log.debug('serving on SSH_AUTH_SOCK=%s', sock_path)
|
||||
try:
|
||||
os.remove(sock_path)
|
||||
except OSError:
|
||||
if os.path.exists(sock_path):
|
||||
raise
|
||||
remove_file(sock_path)
|
||||
|
||||
server = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
|
||||
server.bind(sock_path)
|
||||
@@ -28,7 +32,7 @@ def unix_domain_socket_server(sock_path):
|
||||
try:
|
||||
yield server
|
||||
finally:
|
||||
os.remove(sock_path)
|
||||
remove_file(sock_path)
|
||||
|
||||
|
||||
def handle_connection(conn, handler):
|
||||
|
||||
@@ -47,6 +47,9 @@ def test_handle():
|
||||
server.handle_connection(conn, handler)
|
||||
assert conn.tx.getvalue() == b'\x00\x00\x00\x05\x0C\x00\x00\x00\x00'
|
||||
|
||||
with pytest.raises(AttributeError):
|
||||
server.handle_connection(conn=None, handler=None)
|
||||
|
||||
|
||||
class ServerMock(object):
|
||||
|
||||
@@ -96,3 +99,20 @@ def test_run():
|
||||
def test_serve_main():
|
||||
with server.serve(public_keys=[], signer=None, sock_path=None):
|
||||
pass
|
||||
|
||||
|
||||
def test_remove():
|
||||
path = 'foo.bar'
|
||||
|
||||
def remove(p):
|
||||
assert p == path
|
||||
|
||||
server.remove_file(path, remove=remove)
|
||||
|
||||
def remove_raise(_):
|
||||
raise OSError('boom')
|
||||
|
||||
server.remove_file(path, remove=remove_raise, exists=lambda _: False)
|
||||
|
||||
with pytest.raises(OSError):
|
||||
server.remove_file(path, remove=remove_raise, exists=lambda _: True)
|
||||
|
||||
Reference in New Issue
Block a user