server: stop the server via a threading.Event
It seems that Mac OS does not support calling socket.shutdown(socket.SHUT_RD) on a listening socket (see https://github.com/romanz/trezor-agent/issues/6). The following implementation will set the accept() timeout to 0.1s and stop the server if a threading.Event (named "quit_event") is set by the main thread.
This commit is contained in:
1
tox.ini
1
tox.ini
@@ -12,3 +12,4 @@ commands=
|
||||
pylint --reports=no --rcfile .pylintrc trezor_agent
|
||||
coverage run --omit='trezor_agent/__main__.py' --source trezor_agent -m py.test -v trezor_agent
|
||||
coverage report
|
||||
coverage html
|
||||
|
||||
@@ -12,6 +12,8 @@ from . import util
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
UNIX_SOCKET_TIMEOUT = 0.1
|
||||
|
||||
|
||||
def remove_file(path, remove=os.remove, exists=os.path.exists):
|
||||
try:
|
||||
@@ -44,19 +46,31 @@ def handle_connection(conn, handler):
|
||||
util.send(conn, reply)
|
||||
except EOFError:
|
||||
log.debug('goodbye agent')
|
||||
except:
|
||||
log.exception('error')
|
||||
raise
|
||||
|
||||
|
||||
def server_thread(server, handler):
|
||||
def retry(func, exception_type, quit_event):
|
||||
while True:
|
||||
if quit_event.is_set():
|
||||
raise StopIteration
|
||||
try:
|
||||
return func()
|
||||
except exception_type:
|
||||
pass
|
||||
|
||||
|
||||
def server_thread(server, handler, quit_event):
|
||||
log.debug('server thread started')
|
||||
|
||||
def accept_connection():
|
||||
conn, _ = server.accept()
|
||||
return conn
|
||||
|
||||
while True:
|
||||
log.debug('waiting for connection on %s', server.getsockname())
|
||||
try:
|
||||
conn, _ = server.accept()
|
||||
except socket.error as e:
|
||||
log.debug('server stopped: %s', e)
|
||||
conn = retry(accept_connection, socket.timeout, quit_event)
|
||||
except StopIteration:
|
||||
log.debug('server stopped')
|
||||
break
|
||||
with contextlib.closing(conn):
|
||||
handle_connection(conn, handler)
|
||||
@@ -64,7 +78,7 @@ def server_thread(server, handler):
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def spawn(func, **kwargs):
|
||||
def spawn(func, kwargs):
|
||||
t = threading.Thread(target=func, kwargs=kwargs)
|
||||
t.start()
|
||||
yield
|
||||
@@ -72,20 +86,23 @@ def spawn(func, **kwargs):
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def serve(public_keys, signer, sock_path=None):
|
||||
def serve(public_keys, signer, sock_path=None, timeout=UNIX_SOCKET_TIMEOUT):
|
||||
if sock_path is None:
|
||||
sock_path = tempfile.mktemp(prefix='ssh-agent-')
|
||||
|
||||
keys = [formats.import_public_key(k) for k in public_keys]
|
||||
environ = {'SSH_AUTH_SOCK': sock_path, 'SSH_AGENT_PID': str(os.getpid())}
|
||||
with unix_domain_socket_server(sock_path) as server:
|
||||
server.settimeout(timeout)
|
||||
handler = protocol.Handler(keys=keys, signer=signer)
|
||||
with spawn(server_thread, server=server, handler=handler):
|
||||
quit_event = threading.Event()
|
||||
kwargs = dict(server=server, handler=handler, quit_event=quit_event)
|
||||
with spawn(server_thread, kwargs):
|
||||
try:
|
||||
yield environ
|
||||
finally:
|
||||
log.debug('closing server')
|
||||
server.shutdown(socket.SHUT_RD)
|
||||
quit_event.set()
|
||||
|
||||
|
||||
def run_process(command, environ, use_shell=False):
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import tempfile
|
||||
import socket
|
||||
import threading
|
||||
import os
|
||||
import io
|
||||
import pytest
|
||||
@@ -16,7 +17,7 @@ def test_socket():
|
||||
assert not os.path.isfile(path)
|
||||
|
||||
|
||||
class SocketMock(object):
|
||||
class FakeSocket(object):
|
||||
|
||||
def __init__(self, data=b''):
|
||||
self.rx = io.BytesIO(data)
|
||||
@@ -34,16 +35,16 @@ class SocketMock(object):
|
||||
|
||||
def test_handle():
|
||||
handler = protocol.Handler(keys=[], signer=None)
|
||||
conn = SocketMock()
|
||||
conn = FakeSocket()
|
||||
server.handle_connection(conn, handler)
|
||||
|
||||
msg = bytearray([protocol.SSH_AGENTC_REQUEST_RSA_IDENTITIES])
|
||||
conn = SocketMock(util.frame(msg))
|
||||
conn = FakeSocket(util.frame(msg))
|
||||
server.handle_connection(conn, handler)
|
||||
assert conn.tx.getvalue() == b'\x00\x00\x00\x05\x02\x00\x00\x00\x00'
|
||||
|
||||
msg = bytearray([protocol.SSH2_AGENTC_REQUEST_IDENTITIES])
|
||||
conn = SocketMock(util.frame(msg))
|
||||
conn = FakeSocket(util.frame(msg))
|
||||
server.handle_connection(conn, handler)
|
||||
assert conn.tx.getvalue() == b'\x00\x00\x00\x05\x0C\x00\x00\x00\x00'
|
||||
|
||||
@@ -51,25 +52,24 @@ def test_handle():
|
||||
server.handle_connection(conn=None, handler=None)
|
||||
|
||||
|
||||
class ServerMock(object):
|
||||
|
||||
def __init__(self, connections, name):
|
||||
self.connections = connections
|
||||
self.name = name
|
||||
|
||||
def getsockname(self):
|
||||
return self.name
|
||||
|
||||
def accept(self):
|
||||
if self.connections:
|
||||
return self.connections.pop(), 'address'
|
||||
raise socket.error('stop')
|
||||
|
||||
|
||||
def test_server_thread():
|
||||
s = ServerMock(connections=[SocketMock()], name='mock')
|
||||
h = protocol.Handler(keys=[], signer=None)
|
||||
server.server_thread(s, h)
|
||||
|
||||
connections = [FakeSocket()]
|
||||
quit_event = threading.Event()
|
||||
|
||||
class FakeServer(object):
|
||||
def accept(self): # pylint: disable=no-self-use
|
||||
if connections:
|
||||
return connections.pop(), 'address'
|
||||
quit_event.set()
|
||||
raise socket.timeout()
|
||||
|
||||
def getsockname(self): # pylint: disable=no-self-use
|
||||
return 'fake_server'
|
||||
|
||||
server.server_thread(server=FakeServer(),
|
||||
handler=protocol.Handler(keys=[], signer=None),
|
||||
quit_event=quit_event)
|
||||
|
||||
|
||||
def test_spawn():
|
||||
@@ -78,7 +78,7 @@ def test_spawn():
|
||||
def thread(x):
|
||||
obj.append(x)
|
||||
|
||||
with server.spawn(thread, x=1):
|
||||
with server.spawn(thread, dict(x=1)):
|
||||
pass
|
||||
|
||||
assert obj == [1]
|
||||
|
||||
Reference in New Issue
Block a user