ssh: move related code to a separate subdirectory
This commit is contained in:
@@ -1,19 +1,15 @@
|
||||
"""UNIX-domain socket server for ssh-agent implementation."""
|
||||
import contextlib
|
||||
import functools
|
||||
import logging
|
||||
import os
|
||||
import socket
|
||||
import subprocess
|
||||
import tempfile
|
||||
import threading
|
||||
|
||||
from . import util
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
UNIX_SOCKET_TIMEOUT = 0.1
|
||||
|
||||
|
||||
def remove_file(path, remove=os.remove, exists=os.path.exists):
|
||||
"""Remove file, and raise OSError if still exists."""
|
||||
@@ -114,39 +110,6 @@ def spawn(func, kwargs):
|
||||
t.join()
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def serve(handler, sock_path=None, timeout=UNIX_SOCKET_TIMEOUT):
|
||||
"""
|
||||
Start the ssh-agent server on a UNIX-domain socket.
|
||||
|
||||
If no connection is made during the specified timeout,
|
||||
retry until the context is over.
|
||||
"""
|
||||
ssh_version = subprocess.check_output(['ssh', '-V'],
|
||||
stderr=subprocess.STDOUT)
|
||||
log.debug('local SSH version: %r', ssh_version)
|
||||
if sock_path is None:
|
||||
sock_path = tempfile.mktemp(prefix='trezor-ssh-agent-')
|
||||
|
||||
environ = {'SSH_AUTH_SOCK': sock_path, 'SSH_AGENT_PID': str(os.getpid())}
|
||||
device_mutex = threading.Lock()
|
||||
with unix_domain_socket_server(sock_path) as sock:
|
||||
sock.settimeout(timeout)
|
||||
quit_event = threading.Event()
|
||||
handle_conn = functools.partial(handle_connection,
|
||||
handler=handler,
|
||||
mutex=device_mutex)
|
||||
kwargs = dict(sock=sock,
|
||||
handle_conn=handle_conn,
|
||||
quit_event=quit_event)
|
||||
with spawn(server_thread, kwargs):
|
||||
try:
|
||||
yield environ
|
||||
finally:
|
||||
log.debug('closing server')
|
||||
quit_event.set()
|
||||
|
||||
|
||||
def run_process(command, environ):
|
||||
"""
|
||||
Run the specified process and wait until it finishes.
|
||||
|
||||
@@ -1,16 +1,23 @@
|
||||
"""SSH-agent implementation using hardware authentication devices."""
|
||||
import argparse
|
||||
import contextlib
|
||||
import functools
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
import threading
|
||||
|
||||
from .. import client, device, formats, protocol, server, util
|
||||
|
||||
from .. import device, formats, server, util
|
||||
from . import client, protocol
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
UNIX_SOCKET_TIMEOUT = 0.1
|
||||
|
||||
|
||||
def ssh_args(label):
|
||||
"""Create SSH command for connecting specified server."""
|
||||
@@ -51,7 +58,7 @@ def create_parser():
|
||||
default=formats.CURVE_NIST256,
|
||||
help='specify ECDSA curve name: ' + curve_names)
|
||||
p.add_argument('--timeout',
|
||||
default=server.UNIX_SOCKET_TIMEOUT, type=float,
|
||||
default=UNIX_SOCKET_TIMEOUT, type=float,
|
||||
help='Timeout for accepting SSH client connections')
|
||||
p.add_argument('--debug', default=False, action='store_true',
|
||||
help='Log SSH protocol messages for debugging.')
|
||||
@@ -110,11 +117,44 @@ def git_host(remote_name, attributes):
|
||||
return '{user}@{host}'.format(**match.groupdict())
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def serve(handler, sock_path=None, timeout=UNIX_SOCKET_TIMEOUT):
|
||||
"""
|
||||
Start the ssh-agent server on a UNIX-domain socket.
|
||||
|
||||
If no connection is made during the specified timeout,
|
||||
retry until the context is over.
|
||||
"""
|
||||
ssh_version = subprocess.check_output(['ssh', '-V'],
|
||||
stderr=subprocess.STDOUT)
|
||||
log.debug('local SSH version: %r', ssh_version)
|
||||
if sock_path is None:
|
||||
sock_path = tempfile.mktemp(prefix='trezor-ssh-agent-')
|
||||
|
||||
environ = {'SSH_AUTH_SOCK': sock_path, 'SSH_AGENT_PID': str(os.getpid())}
|
||||
device_mutex = threading.Lock()
|
||||
with server.unix_domain_socket_server(sock_path) as sock:
|
||||
sock.settimeout(timeout)
|
||||
quit_event = threading.Event()
|
||||
handle_conn = functools.partial(server.handle_connection,
|
||||
handler=handler,
|
||||
mutex=device_mutex)
|
||||
kwargs = dict(sock=sock,
|
||||
handle_conn=handle_conn,
|
||||
quit_event=quit_event)
|
||||
with server.spawn(server.server_thread, kwargs):
|
||||
try:
|
||||
yield environ
|
||||
finally:
|
||||
log.debug('closing server')
|
||||
quit_event.set()
|
||||
|
||||
|
||||
def run_server(conn, command, debug, timeout):
|
||||
"""Common code for run_agent and run_git below."""
|
||||
try:
|
||||
handler = protocol.Handler(conn=conn, debug=debug)
|
||||
with server.serve(handler=handler, timeout=timeout) as env:
|
||||
with serve(handler=handler, timeout=timeout) as env:
|
||||
return server.run_process(command=command, environ=env)
|
||||
except KeyboardInterrupt:
|
||||
log.info('server stopped')
|
||||
|
||||
1
libagent/ssh/tests/__init__.py
Normal file
1
libagent/ssh/tests/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Unit-tests for this package."""
|
||||
@@ -8,7 +8,8 @@ import threading
|
||||
import mock
|
||||
import pytest
|
||||
|
||||
from .. import protocol, server, util
|
||||
from .. import server, util
|
||||
from ..ssh import protocol
|
||||
|
||||
|
||||
def test_socket():
|
||||
@@ -117,12 +118,6 @@ def test_run():
|
||||
server.run_process([''], environ={})
|
||||
|
||||
|
||||
def test_serve_main():
|
||||
handler = protocol.Handler(conn=empty_device())
|
||||
with server.serve(handler=handler, sock_path=None):
|
||||
pass
|
||||
|
||||
|
||||
def test_remove():
|
||||
path = 'foo.bar'
|
||||
|
||||
|
||||
Reference in New Issue
Block a user