diff --git a/sshagent/server.py b/sshagent/server.py index 92ea261..6d0f4b6 100644 --- a/sshagent/server.py +++ b/sshagent/server.py @@ -13,14 +13,18 @@ import logging log = logging.getLogger(__name__) +def remove_file(path, remove=os.remove, exists=os.path.exists): + try: + remove(path) + except OSError: + if exists(path): + raise + + @contextlib.contextmanager def unix_domain_socket_server(sock_path): log.debug('serving on SSH_AUTH_SOCK=%s', sock_path) - try: - os.remove(sock_path) - except OSError: - if os.path.exists(sock_path): - raise + remove_file(sock_path) server = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) server.bind(sock_path) @@ -28,7 +32,7 @@ def unix_domain_socket_server(sock_path): try: yield server finally: - os.remove(sock_path) + remove_file(sock_path) def handle_connection(conn, handler): diff --git a/sshagent/tests/test_server.py b/sshagent/tests/test_server.py index 82f4100..1dae44a 100644 --- a/sshagent/tests/test_server.py +++ b/sshagent/tests/test_server.py @@ -47,6 +47,9 @@ def test_handle(): server.handle_connection(conn, handler) assert conn.tx.getvalue() == b'\x00\x00\x00\x05\x0C\x00\x00\x00\x00' + with pytest.raises(AttributeError): + server.handle_connection(conn=None, handler=None) + class ServerMock(object): @@ -96,3 +99,20 @@ def test_run(): def test_serve_main(): with server.serve(public_keys=[], signer=None, sock_path=None): pass + + +def test_remove(): + path = 'foo.bar' + + def remove(p): + assert p == path + + server.remove_file(path, remove=remove) + + def remove_raise(_): + raise OSError('boom') + + server.remove_file(path, remove=remove_raise, exists=lambda _: False) + + with pytest.raises(OSError): + server.remove_file(path, remove=remove_raise, exists=lambda _: True)