Add Windows native SSH support

This commit is contained in:
Taylor Buchanan
2021-09-18 14:00:20 -05:00
parent 6d55512619
commit 498093f2f6
5 changed files with 275 additions and 34 deletions

View File

@@ -78,7 +78,8 @@ class UI:
def create_default_options_getter():
"""Return current TTY and DISPLAY settings for GnuPG pinentry."""
options = []
if sys.stdin.isatty(): # short-circuit calling `tty`
# Windows reports that it has a TTY but throws FileNotFoundError
if sys.platform != 'win32' and sys.stdin.isatty(): # short-circuit calling `tty`
try:
ttyname = subprocess.check_output(args=['tty']).strip()
options.append(b'ttyname=' + ttyname)
@@ -88,7 +89,8 @@ def create_default_options_getter():
display = os.environ.get('DISPLAY')
if display is not None:
options.append('display={}'.format(display).encode('ascii'))
else:
# Windows likely doesn't support this anyway
elif sys.platform != 'win32':
log.warning('DISPLAY not defined')
log.info('using %s for pinentry options', options)

View File

@@ -4,24 +4,32 @@ import functools
import io
import logging
import os
import random
import re
import signal
import string
import subprocess
import sys
import tempfile
import threading
import configargparse
import daemon
try:
# TODO: Not supported on Windows. Use daemoniker instead?
import daemon
except ImportError:
daemon = None
import pkg_resources
from .. import device, formats, server, util
from .. import device, formats, server, util, win_server
from . import client, protocol
log = logging.getLogger(__name__)
UNIX_SOCKET_TIMEOUT = 0.1
WIN_PIPE_TIMEOUT = 0.1
DEFAULT_TIMEOUT = WIN_PIPE_TIMEOUT if sys.platform == 'win32' else UNIX_SOCKET_TIMEOUT
SOCK_TYPE = 'Windows named pipe' if sys.platform == 'win32' else 'UNIX domain socket'
def ssh_args(conn):
"""Create SSH command for connecting specified server."""
@@ -35,7 +43,7 @@ def ssh_args(conn):
if 'user' in identity:
args += ['-l', identity['user']]
args += ['-o', 'IdentityFile={}'.format(pubkey_tempfile.name)]
args += ['-o', 'IdentityFile={}'.format(pubkey_tempfile)]
args += ['-o', 'IdentitiesOnly=true']
return args + [identity['host']]
@@ -83,14 +91,14 @@ def create_agent_parser(device_type):
default=formats.CURVE_NIST256,
help='specify ECDSA curve name: ' + curve_names)
p.add_argument('--timeout',
default=UNIX_SOCKET_TIMEOUT, type=float,
default=DEFAULT_TIMEOUT, type=float,
help='timeout for accepting SSH client connections')
p.add_argument('--debug', default=False, action='store_true',
help='log SSH protocol messages for debugging.')
p.add_argument('--log-file', type=str,
help='Path to the log file (to be written by the agent).')
p.add_argument('--sock-path', type=str,
help='Path to the UNIX domain socket of the agent.')
help='Path to the ' + SOCK_TYPE + ' of the agent.')
p.add_argument('--pin-entry-binary', type=str, default='pinentry',
help='Path to PIN entry UI helper.')
@@ -100,15 +108,18 @@ def create_agent_parser(device_type):
help='Expire passphrase from cache after this duration.')
g = p.add_mutually_exclusive_group()
if daemon:
g.add_argument('-d', '--daemonize', default=False, action='store_true',
help='Daemonize the agent and print its UNIX socket path')
help='Daemonize the agent and print its ' + SOCK_TYPE)
g.add_argument('-f', '--foreground', default=False, action='store_true',
help='Run agent in foreground with specified UNIX socket path')
help='Run agent in foreground with specified ' + SOCK_TYPE)
g.add_argument('-s', '--shell', default=False, action='store_true',
help=('run ${SHELL} as subprocess under SSH agent, allowing '
'regular SSH-based tools to be used in the shell'))
g.add_argument('-c', '--connect', default=False, action='store_true',
help='connect to specified host via SSH')
# Windows doesn't have native mosh
if sys.platform != 'win32':
g.add_argument('--mosh', default=False, action='store_true',
help='connect to specified host via using Mosh')
@@ -119,18 +130,48 @@ def create_agent_parser(device_type):
return p
def get_ssh_env(sock_path):
ssh_version = subprocess.check_output(['ssh', '-V'],
stderr=subprocess.STDOUT)
log.debug('local SSH version: %r', ssh_version)
return {'SSH_AUTH_SOCK': sock_path, 'SSH_AGENT_PID': str(os.getpid())}
# Windows doesn't support AF_UNIX yet
# https://bugs.python.org/issue33408
@contextlib.contextmanager
def serve(handler, sock_path, timeout=UNIX_SOCKET_TIMEOUT):
def serve_win(handler, sock_path, timeout=WIN_PIPE_TIMEOUT):
"""
Start the ssh-agent server on a Windows named pipe.
"""
environ = get_ssh_env(sock_path)
device_mutex = threading.Lock()
quit_event = threading.Event()
handle_conn = functools.partial(win_server.handle_connection,
handler=handler,
mutex=device_mutex,
quit_event=quit_event)
kwargs = dict(pipe_name=sock_path,
handle_conn=handle_conn,
quit_event=quit_event,
timeout=timeout)
with server.spawn(win_server.server_thread, kwargs):
try:
yield environ
finally:
log.debug('closing server')
quit_event.set()
@contextlib.contextmanager
def serve_unix(handler, sock_path, timeout=UNIX_SOCKET_TIMEOUT):
"""
Start the ssh-agent server on a UNIX-domain socket.
If no connection is made during the specified timeout,
retry until the context is over.
"""
ssh_version = subprocess.check_output(['ssh', '-V'],
stderr=subprocess.STDOUT)
log.debug('local SSH version: %r', ssh_version)
environ = {'SSH_AUTH_SOCK': sock_path, 'SSH_AGENT_PID': str(os.getpid())}
environ = get_ssh_env(sock_path)
device_mutex = threading.Lock()
with server.unix_domain_socket_server(sock_path) as sock:
sock.settimeout(timeout)
@@ -154,12 +195,15 @@ def run_server(conn, command, sock_path, debug, timeout):
ret = 0
try:
handler = protocol.Handler(conn=conn, debug=debug)
with serve(handler=handler, sock_path=sock_path,
timeout=timeout) as env:
serve_platform = serve_win if sys.platform == 'win32' else serve_unix
with serve_platform(handler=handler, sock_path=sock_path, timeout=timeout) as env:
if command:
ret = server.run_process(command=command, environ=env)
else:
try:
signal.pause() # wait for signal (e.g. SIGINT)
except AttributeError:
sys.stdin.read() # Windows doesn't support signal.pause
except KeyboardInterrupt:
log.info('server stopped')
return ret
@@ -221,10 +265,9 @@ class JustInTimeConnection:
"""Store public keys as temporary SSH identity files."""
if not self.public_keys_tempfiles:
for pk in self.public_keys():
f = tempfile.NamedTemporaryFile(prefix='trezor-ssh-pubkey-', mode='w')
with tempfile.NamedTemporaryFile(prefix='trezor-ssh-pubkey-', mode='w', delete=False, newline='') as f:
f.write(pk)
f.flush()
self.public_keys_tempfiles.append(f)
self.public_keys_tempfiles.append(f.name)
return self.public_keys_tempfiles
@@ -241,13 +284,16 @@ def _dummy_context():
def _get_sock_path(args):
sock_path = args.sock_path
if not sock_path:
if args.foreground:
log.error('running in foreground mode requires specifying UNIX socket path')
sys.exit(1)
else:
sock_path = tempfile.mktemp(prefix='trezor-ssh-agent-')
if sock_path:
return sock_path
elif args.foreground:
log.error('running in foreground mode requires specifying ' + SOCK_TYPE)
sys.exit(1)
elif sys.platform == 'win32':
suffix = random.choices(string.ascii_letters, k=10)
return '\\\\.\pipe\\trezor-ssh-agent-' + ''.join(suffix)
else:
return tempfile.mktemp(prefix='trezor-ssh-agent-')
@handle_connection_error
@@ -286,7 +332,7 @@ def main(device_type):
command = ['ssh'] + ssh_args(conn) + args.command
elif args.mosh:
command = ['mosh'] + mosh_args(conn) + args.command
elif args.daemonize:
elif daemon and args.daemonize:
out = 'SSH_AUTH_SOCK={0}; export SSH_AUTH_SOCK;\n'.format(sock_path)
sys.stdout.write(out)
sys.stdout.flush()
@@ -300,7 +346,7 @@ def main(device_type):
command = os.environ['SHELL']
sys.stdin.close()
if command or args.daemonize or args.foreground:
if command or (daemon and args.daemonize) or args.foreground:
with context:
return run_server(conn=conn, command=command, sock_path=sock_path,
debug=args.debug, timeout=args.timeout)

191
libagent/win_server.py Normal file
View File

@@ -0,0 +1,191 @@
"""Windows named pipe server for ssh-agent implementation."""
import logging
import pywintypes
import struct
import threading
import win32api
import win32event
import win32pipe
import win32file
import winerror
from . import util
log = logging.getLogger(__name__)
PIPE_BUFFER_SIZE = 64 * 1024
# Make MemoryView look like a buffer to reuse util.recv
class MvBuffer:
def __init__(self, mv):
self.mv = mv
def read(self, n):
return self.mv[0:n]
# Based loosely on https://docs.microsoft.com/en-us/windows/win32/ipc/multithreaded-pipe-server
class NamedPipe:
__frame_size_size = struct.calcsize('>L')
def __close(handle):
"""Closes a named pipe handle."""
if handle == win32file.INVALID_HANDLE_VALUE:
return
win32file.FlushFileBuffers(handle)
win32pipe.DisconnectNamedPipe(handle)
win32api.CloseHandle(handle)
def open(name):
"""Opens a named pipe server for receiving connections."""
handle = win32pipe.CreateNamedPipe(
name,
win32pipe.PIPE_ACCESS_DUPLEX | win32file.FILE_FLAG_OVERLAPPED,
win32pipe.PIPE_TYPE_MESSAGE | win32pipe.PIPE_READMODE_MESSAGE | win32pipe.PIPE_WAIT,
win32pipe.PIPE_UNLIMITED_INSTANCES,
PIPE_BUFFER_SIZE,
PIPE_BUFFER_SIZE,
0,
None) # Default security attributes
if handle == win32file.INVALID_HANDLE_VALUE:
log.error("CreateNamedPipe failed (%d)", win32api.GetLastError())
return None
try:
pending_io = False
overlapped = win32file.OVERLAPPED()
overlapped.hEvent = win32event.CreateEvent(None, True, True, None)
error_code = win32pipe.ConnectNamedPipe(handle, overlapped)
if error_code == winerror.ERROR_IO_PENDING:
pending_io = True
elif error_code != winerror.ERROR_PIPE_CONNECTED or not win32event.SetEvent(overlapped.hEvent):
log.error('ConnectNamedPipe failed (%d)', error_code)
return None
log.debug('waiting for connection on %s', name)
return NamedPipe(name, handle, overlapped, pending_io)
except:
NamedPipe.__close(handle)
raise
def __init__(self, name, handle, overlapped, pending_io):
self.name = name
self.handle = handle
self.overlapped = overlapped
self.pending_io = pending_io
def close(self):
"""Close the named pipe."""
NamedPipe.__close(self.handle)
def connect(self, timeout):
"""Connect to an SSH client with the specified timeout."""
waitHandle = win32event.WaitForSingleObject(
self.overlapped.hEvent,
timeout)
if waitHandle == win32event.WAIT_TIMEOUT:
return False
if not self.pending_io:
return True
win32pipe.GetOverlappedResult(
self.handle,
self.overlapped,
False)
error_code = win32api.GetLastError()
if error_code == winerror.NO_ERROR:
return True
log.error('GetOverlappedResult failed (%d)', error_code)
return False
def read_frame(self, quit_event):
"""Read the request frame from the SSH client."""
request_size = None
remaining = None
buf = MvBuffer(win32file.AllocateReadBuffer(PIPE_BUFFER_SIZE))
while True:
if quit_event.is_set():
return None
error_code, _ = win32file.ReadFile(self.handle, buf.mv, self.overlapped)
if error_code not in (winerror.NO_ERROR, winerror.ERROR_IO_PENDING, winerror.ERROR_MORE_DATA):
log.error('ReadFile failed (%d)', error_code)
return None
win32event.WaitForSingleObject(self.overlapped.hEvent, win32event.INFINITE)
chunk_size = win32pipe.GetOverlappedResult(self.handle, self.overlapped, False)
error_code = win32api.GetLastError()
if error_code != winerror.NO_ERROR:
log.error('GetOverlappedResult failed (%d)', error_code)
return None
if request_size:
remaining -= chunk_size
else:
request_size, = util.recv(buf, '>L')
remaining = request_size - (chunk_size - NamedPipe.__frame_size_size)
if remaining <= 0:
break
return util.recv(buf, request_size)
def send(self, reply):
"""Send the specified reply to the SSH client."""
error_code, _ = win32file.WriteFile(self.handle, reply)
if error_code == winerror.NO_ERROR:
return True
log.error('WriteFile failed (%d)', error_code)
return False
def handle_connection(pipe, handler, mutex, quit_event):
"""
Handle a single connection using the specified protocol handler in a loop.
Since this function may be called concurrently from server_thread,
the specified mutex is used to synchronize the device handling.
"""
log.debug('welcome agent')
try:
while True:
if quit_event.is_set():
return
msg = pipe.read_frame(quit_event)
if not msg:
return
with mutex:
reply = handler.handle(msg=msg)
if not pipe.send(reply):
return
except pywintypes.error as e:
# Surface errors that aren't related to the client disconnecting
if e.args[0] == winerror.ERROR_BROKEN_PIPE:
log.debug('goodbye agent')
else:
raise
except Exception as e: # pylint: disable=broad-except
log.warning('error: %s', e, exc_info=True)
finally:
pipe.close()
def server_thread(pipe_name, handle_conn, quit_event, timeout):
"""Run a Windows server on the specified pipe."""
log.debug('server thread started')
while True:
if quit_event.is_set():
break
# A new pipe instance is necessary for each client
pipe = NamedPipe.open(pipe_name)
if not pipe:
break
try:
# Poll for a new client connection
while True:
if quit_event.is_set():
break
if pipe.connect(timeout * 1000):
# Handle connections from SSH concurrently.
threading.Thread(target=handle_conn,
kwargs=dict(pipe=pipe)).start()
break
except:
pipe.close()
raise
log.debug('server thread stopped')

View File

@@ -27,6 +27,7 @@ setup(
'pymsgbox>=1.0.6',
'semver>=2.2',
'unidecode>=0.4.20',
'pypiwin32'
],
platforms=['POSIX'],
classifiers=[

View File

@@ -14,6 +14,7 @@ deps=
semver
pydocstyle
isort<5
pypiwin32
commands=
pycodestyle libagent
isort --skip-glob .tox -c -rc libagent