server: pass handler and add debug option
This commit is contained in:
@@ -6,7 +6,7 @@ import re
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
from . import formats, server, trezor
|
||||
from . import formats, protocol, server, trezor
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
@@ -71,6 +71,8 @@ def create_agent_parser():
|
||||
p.add_argument('--timeout',
|
||||
default=server.UNIX_SOCKET_TIMEOUT, type=float,
|
||||
help='Timeout for accepting SSH client connections')
|
||||
p.add_argument('--debug', default=False, action='store_true',
|
||||
help='Log SSH protocol messages for debugging.')
|
||||
p.add_argument('command', type=str, nargs='*', metavar='ARGUMENT',
|
||||
help='command to run under the SSH agent')
|
||||
return p
|
||||
@@ -119,9 +121,10 @@ def run_agent(client_factory):
|
||||
|
||||
try:
|
||||
signer = functools.partial(ssh_sign, client=client)
|
||||
with server.serve(public_keys=[public_key],
|
||||
signer=signer,
|
||||
timeout=args.timeout) as env:
|
||||
public_keys = [formats.import_public_key(public_key)]
|
||||
handler = protocol.Handler(keys=public_keys, signer=signer,
|
||||
debug=args.debug)
|
||||
with server.serve(handler=handler, timeout=args.timeout) as env:
|
||||
return server.run_process(command=command,
|
||||
environ=env,
|
||||
use_shell=use_shell)
|
||||
|
||||
@@ -6,7 +6,7 @@ import subprocess
|
||||
import tempfile
|
||||
import threading
|
||||
|
||||
from . import formats, protocol, util
|
||||
from . import util
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
@@ -87,15 +87,13 @@ def spawn(func, kwargs):
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def serve(public_keys, signer, sock_path=None, timeout=UNIX_SOCKET_TIMEOUT):
|
||||
def serve(handler, sock_path=None, timeout=UNIX_SOCKET_TIMEOUT):
|
||||
if sock_path is None:
|
||||
sock_path = tempfile.mktemp(prefix='ssh-agent-')
|
||||
|
||||
keys = [formats.import_public_key(k) for k in public_keys]
|
||||
environ = {'SSH_AUTH_SOCK': sock_path, 'SSH_AGENT_PID': str(os.getpid())}
|
||||
with unix_domain_socket_server(sock_path) as server:
|
||||
server.settimeout(timeout)
|
||||
handler = protocol.Handler(keys=keys, signer=signer)
|
||||
quit_event = threading.Event()
|
||||
kwargs = dict(server=server, handler=handler, quit_event=quit_event)
|
||||
with spawn(server_thread, kwargs):
|
||||
|
||||
@@ -101,7 +101,8 @@ def test_run():
|
||||
|
||||
|
||||
def test_serve_main():
|
||||
with server.serve(public_keys=[], signer=None, sock_path=None):
|
||||
handler = protocol.Handler(keys=[], signer=None)
|
||||
with server.serve(handler=handler, sock_path=None):
|
||||
pass
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user