ssh: move related code to a separate subdirectory

This commit is contained in:
Roman Zeyde
2017-05-05 11:22:00 +03:00
parent 6c2273387d
commit 257992d04c
8 changed files with 46 additions and 47 deletions

View File

@@ -1,19 +1,15 @@
"""UNIX-domain socket server for ssh-agent implementation."""
import contextlib
import functools
import logging
import os
import socket
import subprocess
import tempfile
import threading
from . import util
log = logging.getLogger(__name__)
UNIX_SOCKET_TIMEOUT = 0.1
def remove_file(path, remove=os.remove, exists=os.path.exists):
"""Remove file, and raise OSError if still exists."""
@@ -114,39 +110,6 @@ def spawn(func, kwargs):
t.join()
@contextlib.contextmanager
def serve(handler, sock_path=None, timeout=UNIX_SOCKET_TIMEOUT):
"""
Start the ssh-agent server on a UNIX-domain socket.
If no connection is made during the specified timeout,
retry until the context is over.
"""
ssh_version = subprocess.check_output(['ssh', '-V'],
stderr=subprocess.STDOUT)
log.debug('local SSH version: %r', ssh_version)
if sock_path is None:
sock_path = tempfile.mktemp(prefix='trezor-ssh-agent-')
environ = {'SSH_AUTH_SOCK': sock_path, 'SSH_AGENT_PID': str(os.getpid())}
device_mutex = threading.Lock()
with unix_domain_socket_server(sock_path) as sock:
sock.settimeout(timeout)
quit_event = threading.Event()
handle_conn = functools.partial(handle_connection,
handler=handler,
mutex=device_mutex)
kwargs = dict(sock=sock,
handle_conn=handle_conn,
quit_event=quit_event)
with spawn(server_thread, kwargs):
try:
yield environ
finally:
log.debug('closing server')
quit_event.set()
def run_process(command, environ):
"""
Run the specified process and wait until it finishes.

View File

@@ -1,16 +1,23 @@
"""SSH-agent implementation using hardware authentication devices."""
import argparse
import contextlib
import functools
import logging
import os
import re
import subprocess
import sys
import tempfile
import threading
from .. import client, device, formats, protocol, server, util
from .. import device, formats, server, util
from . import client, protocol
log = logging.getLogger(__name__)
UNIX_SOCKET_TIMEOUT = 0.1
def ssh_args(label):
"""Create SSH command for connecting specified server."""
@@ -51,7 +58,7 @@ def create_parser():
default=formats.CURVE_NIST256,
help='specify ECDSA curve name: ' + curve_names)
p.add_argument('--timeout',
default=server.UNIX_SOCKET_TIMEOUT, type=float,
default=UNIX_SOCKET_TIMEOUT, type=float,
help='Timeout for accepting SSH client connections')
p.add_argument('--debug', default=False, action='store_true',
help='Log SSH protocol messages for debugging.')
@@ -110,11 +117,44 @@ def git_host(remote_name, attributes):
return '{user}@{host}'.format(**match.groupdict())
@contextlib.contextmanager
def serve(handler, sock_path=None, timeout=UNIX_SOCKET_TIMEOUT):
"""
Start the ssh-agent server on a UNIX-domain socket.
If no connection is made during the specified timeout,
retry until the context is over.
"""
ssh_version = subprocess.check_output(['ssh', '-V'],
stderr=subprocess.STDOUT)
log.debug('local SSH version: %r', ssh_version)
if sock_path is None:
sock_path = tempfile.mktemp(prefix='trezor-ssh-agent-')
environ = {'SSH_AUTH_SOCK': sock_path, 'SSH_AGENT_PID': str(os.getpid())}
device_mutex = threading.Lock()
with server.unix_domain_socket_server(sock_path) as sock:
sock.settimeout(timeout)
quit_event = threading.Event()
handle_conn = functools.partial(server.handle_connection,
handler=handler,
mutex=device_mutex)
kwargs = dict(sock=sock,
handle_conn=handle_conn,
quit_event=quit_event)
with server.spawn(server.server_thread, kwargs):
try:
yield environ
finally:
log.debug('closing server')
quit_event.set()
def run_server(conn, command, debug, timeout):
"""Common code for run_agent and run_git below."""
try:
handler = protocol.Handler(conn=conn, debug=debug)
with server.serve(handler=handler, timeout=timeout) as env:
with serve(handler=handler, timeout=timeout) as env:
return server.run_process(command=command, environ=env)
except KeyboardInterrupt:
log.info('server stopped')

View File

@@ -0,0 +1 @@
"""Unit-tests for this package."""

View File

@@ -8,7 +8,8 @@ import threading
import mock
import pytest
from .. import protocol, server, util
from .. import server, util
from ..ssh import protocol
def test_socket():
@@ -117,12 +118,6 @@ def test_run():
server.run_process([''], environ={})
def test_serve_main():
handler = protocol.Handler(conn=empty_device())
with server.serve(handler=handler, sock_path=None):
pass
def test_remove():
path = 'foo.bar'