ssh: move related code to a separate subdirectory
This commit is contained in:
@@ -1,19 +1,15 @@
|
|||||||
"""UNIX-domain socket server for ssh-agent implementation."""
|
"""UNIX-domain socket server for ssh-agent implementation."""
|
||||||
import contextlib
|
import contextlib
|
||||||
import functools
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import socket
|
import socket
|
||||||
import subprocess
|
import subprocess
|
||||||
import tempfile
|
|
||||||
import threading
|
import threading
|
||||||
|
|
||||||
from . import util
|
from . import util
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
UNIX_SOCKET_TIMEOUT = 0.1
|
|
||||||
|
|
||||||
|
|
||||||
def remove_file(path, remove=os.remove, exists=os.path.exists):
|
def remove_file(path, remove=os.remove, exists=os.path.exists):
|
||||||
"""Remove file, and raise OSError if still exists."""
|
"""Remove file, and raise OSError if still exists."""
|
||||||
@@ -114,39 +110,6 @@ def spawn(func, kwargs):
|
|||||||
t.join()
|
t.join()
|
||||||
|
|
||||||
|
|
||||||
@contextlib.contextmanager
|
|
||||||
def serve(handler, sock_path=None, timeout=UNIX_SOCKET_TIMEOUT):
|
|
||||||
"""
|
|
||||||
Start the ssh-agent server on a UNIX-domain socket.
|
|
||||||
|
|
||||||
If no connection is made during the specified timeout,
|
|
||||||
retry until the context is over.
|
|
||||||
"""
|
|
||||||
ssh_version = subprocess.check_output(['ssh', '-V'],
|
|
||||||
stderr=subprocess.STDOUT)
|
|
||||||
log.debug('local SSH version: %r', ssh_version)
|
|
||||||
if sock_path is None:
|
|
||||||
sock_path = tempfile.mktemp(prefix='trezor-ssh-agent-')
|
|
||||||
|
|
||||||
environ = {'SSH_AUTH_SOCK': sock_path, 'SSH_AGENT_PID': str(os.getpid())}
|
|
||||||
device_mutex = threading.Lock()
|
|
||||||
with unix_domain_socket_server(sock_path) as sock:
|
|
||||||
sock.settimeout(timeout)
|
|
||||||
quit_event = threading.Event()
|
|
||||||
handle_conn = functools.partial(handle_connection,
|
|
||||||
handler=handler,
|
|
||||||
mutex=device_mutex)
|
|
||||||
kwargs = dict(sock=sock,
|
|
||||||
handle_conn=handle_conn,
|
|
||||||
quit_event=quit_event)
|
|
||||||
with spawn(server_thread, kwargs):
|
|
||||||
try:
|
|
||||||
yield environ
|
|
||||||
finally:
|
|
||||||
log.debug('closing server')
|
|
||||||
quit_event.set()
|
|
||||||
|
|
||||||
|
|
||||||
def run_process(command, environ):
|
def run_process(command, environ):
|
||||||
"""
|
"""
|
||||||
Run the specified process and wait until it finishes.
|
Run the specified process and wait until it finishes.
|
||||||
|
|||||||
@@ -1,16 +1,23 @@
|
|||||||
"""SSH-agent implementation using hardware authentication devices."""
|
"""SSH-agent implementation using hardware authentication devices."""
|
||||||
import argparse
|
import argparse
|
||||||
|
import contextlib
|
||||||
import functools
|
import functools
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import subprocess
|
import subprocess
|
||||||
import sys
|
import sys
|
||||||
|
import tempfile
|
||||||
|
import threading
|
||||||
|
|
||||||
from .. import client, device, formats, protocol, server, util
|
|
||||||
|
from .. import device, formats, server, util
|
||||||
|
from . import client, protocol
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
UNIX_SOCKET_TIMEOUT = 0.1
|
||||||
|
|
||||||
|
|
||||||
def ssh_args(label):
|
def ssh_args(label):
|
||||||
"""Create SSH command for connecting specified server."""
|
"""Create SSH command for connecting specified server."""
|
||||||
@@ -51,7 +58,7 @@ def create_parser():
|
|||||||
default=formats.CURVE_NIST256,
|
default=formats.CURVE_NIST256,
|
||||||
help='specify ECDSA curve name: ' + curve_names)
|
help='specify ECDSA curve name: ' + curve_names)
|
||||||
p.add_argument('--timeout',
|
p.add_argument('--timeout',
|
||||||
default=server.UNIX_SOCKET_TIMEOUT, type=float,
|
default=UNIX_SOCKET_TIMEOUT, type=float,
|
||||||
help='Timeout for accepting SSH client connections')
|
help='Timeout for accepting SSH client connections')
|
||||||
p.add_argument('--debug', default=False, action='store_true',
|
p.add_argument('--debug', default=False, action='store_true',
|
||||||
help='Log SSH protocol messages for debugging.')
|
help='Log SSH protocol messages for debugging.')
|
||||||
@@ -110,11 +117,44 @@ def git_host(remote_name, attributes):
|
|||||||
return '{user}@{host}'.format(**match.groupdict())
|
return '{user}@{host}'.format(**match.groupdict())
|
||||||
|
|
||||||
|
|
||||||
|
@contextlib.contextmanager
|
||||||
|
def serve(handler, sock_path=None, timeout=UNIX_SOCKET_TIMEOUT):
|
||||||
|
"""
|
||||||
|
Start the ssh-agent server on a UNIX-domain socket.
|
||||||
|
|
||||||
|
If no connection is made during the specified timeout,
|
||||||
|
retry until the context is over.
|
||||||
|
"""
|
||||||
|
ssh_version = subprocess.check_output(['ssh', '-V'],
|
||||||
|
stderr=subprocess.STDOUT)
|
||||||
|
log.debug('local SSH version: %r', ssh_version)
|
||||||
|
if sock_path is None:
|
||||||
|
sock_path = tempfile.mktemp(prefix='trezor-ssh-agent-')
|
||||||
|
|
||||||
|
environ = {'SSH_AUTH_SOCK': sock_path, 'SSH_AGENT_PID': str(os.getpid())}
|
||||||
|
device_mutex = threading.Lock()
|
||||||
|
with server.unix_domain_socket_server(sock_path) as sock:
|
||||||
|
sock.settimeout(timeout)
|
||||||
|
quit_event = threading.Event()
|
||||||
|
handle_conn = functools.partial(server.handle_connection,
|
||||||
|
handler=handler,
|
||||||
|
mutex=device_mutex)
|
||||||
|
kwargs = dict(sock=sock,
|
||||||
|
handle_conn=handle_conn,
|
||||||
|
quit_event=quit_event)
|
||||||
|
with server.spawn(server.server_thread, kwargs):
|
||||||
|
try:
|
||||||
|
yield environ
|
||||||
|
finally:
|
||||||
|
log.debug('closing server')
|
||||||
|
quit_event.set()
|
||||||
|
|
||||||
|
|
||||||
def run_server(conn, command, debug, timeout):
|
def run_server(conn, command, debug, timeout):
|
||||||
"""Common code for run_agent and run_git below."""
|
"""Common code for run_agent and run_git below."""
|
||||||
try:
|
try:
|
||||||
handler = protocol.Handler(conn=conn, debug=debug)
|
handler = protocol.Handler(conn=conn, debug=debug)
|
||||||
with server.serve(handler=handler, timeout=timeout) as env:
|
with serve(handler=handler, timeout=timeout) as env:
|
||||||
return server.run_process(command=command, environ=env)
|
return server.run_process(command=command, environ=env)
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
log.info('server stopped')
|
log.info('server stopped')
|
||||||
|
|||||||
1
libagent/ssh/tests/__init__.py
Normal file
1
libagent/ssh/tests/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
"""Unit-tests for this package."""
|
||||||
@@ -8,7 +8,8 @@ import threading
|
|||||||
import mock
|
import mock
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from .. import protocol, server, util
|
from .. import server, util
|
||||||
|
from ..ssh import protocol
|
||||||
|
|
||||||
|
|
||||||
def test_socket():
|
def test_socket():
|
||||||
@@ -117,12 +118,6 @@ def test_run():
|
|||||||
server.run_process([''], environ={})
|
server.run_process([''], environ={})
|
||||||
|
|
||||||
|
|
||||||
def test_serve_main():
|
|
||||||
handler = protocol.Handler(conn=empty_device())
|
|
||||||
with server.serve(handler=handler, sock_path=None):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
def test_remove():
|
def test_remove():
|
||||||
path = 'foo.bar'
|
path = 'foo.bar'
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user