diff --git a/trezor_agent/server.py b/trezor_agent/server.py index 5680614..d4059ac 100644 --- a/trezor_agent/server.py +++ b/trezor_agent/server.py @@ -1,5 +1,6 @@ """UNIX-domain socket server for ssh-agent implementation.""" import contextlib +import functools import logging import os import socket @@ -77,7 +78,7 @@ def retry(func, exception_type, quit_event): pass -def server_thread(sock, handler, quit_event): +def server_thread(sock, handle_conn, quit_event): """Run a server on the specified socket.""" log.debug('server thread started') @@ -94,7 +95,7 @@ def server_thread(sock, handler, quit_event): log.debug('server stopped') break with contextlib.closing(conn): - handle_connection(conn, handler) + handle_conn(conn) log.debug('server thread stopped') @@ -122,7 +123,10 @@ def serve(handler, sock_path=None, timeout=UNIX_SOCKET_TIMEOUT): with unix_domain_socket_server(sock_path) as sock: sock.settimeout(timeout) quit_event = threading.Event() - kwargs = dict(sock=sock, handler=handler, quit_event=quit_event) + handle_conn = functools.partial(handle_connection, handler=handler) + kwargs = dict(sock=sock, + handle_conn=handle_conn, + quit_event=quit_event) with spawn(server_thread, kwargs): try: yield environ diff --git a/trezor_agent/tests/test_server.py b/trezor_agent/tests/test_server.py index 31487e6..af7f27d 100644 --- a/trezor_agent/tests/test_server.py +++ b/trezor_agent/tests/test_server.py @@ -1,3 +1,4 @@ +import functools import io import os import socket @@ -78,8 +79,10 @@ def test_server_thread(): def getsockname(self): # pylint: disable=no-self-use return 'fake_server' + handler = protocol.Handler(keys=[], signer=None), + handle_conn = functools.partial(server.handle_connection, handler=handler) server.server_thread(sock=FakeServer(), - handler=protocol.Handler(keys=[], signer=None), + handle_conn=handle_conn, quit_event=quit_event)