Add Windows native SSH support

This commit is contained in:
Taylor Buchanan
2021-09-18 14:00:20 -05:00
parent 6d55512619
commit 498093f2f6
5 changed files with 275 additions and 34 deletions

View File

@@ -78,7 +78,8 @@ class UI:
def create_default_options_getter(): def create_default_options_getter():
"""Return current TTY and DISPLAY settings for GnuPG pinentry.""" """Return current TTY and DISPLAY settings for GnuPG pinentry."""
options = [] options = []
if sys.stdin.isatty(): # short-circuit calling `tty` # Windows reports that it has a TTY but throws FileNotFoundError
if sys.platform != 'win32' and sys.stdin.isatty(): # short-circuit calling `tty`
try: try:
ttyname = subprocess.check_output(args=['tty']).strip() ttyname = subprocess.check_output(args=['tty']).strip()
options.append(b'ttyname=' + ttyname) options.append(b'ttyname=' + ttyname)
@@ -88,7 +89,8 @@ def create_default_options_getter():
display = os.environ.get('DISPLAY') display = os.environ.get('DISPLAY')
if display is not None: if display is not None:
options.append('display={}'.format(display).encode('ascii')) options.append('display={}'.format(display).encode('ascii'))
else: # Windows likely doesn't support this anyway
elif sys.platform != 'win32':
log.warning('DISPLAY not defined') log.warning('DISPLAY not defined')
log.info('using %s for pinentry options', options) log.info('using %s for pinentry options', options)

View File

@@ -4,24 +4,32 @@ import functools
import io import io
import logging import logging
import os import os
import random
import re import re
import signal import signal
import string
import subprocess import subprocess
import sys import sys
import tempfile import tempfile
import threading import threading
import configargparse import configargparse
try:
# TODO: Not supported on Windows. Use daemoniker instead?
import daemon import daemon
except ImportError:
daemon = None
import pkg_resources import pkg_resources
from .. import device, formats, server, util from .. import device, formats, server, util, win_server
from . import client, protocol from . import client, protocol
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
UNIX_SOCKET_TIMEOUT = 0.1 UNIX_SOCKET_TIMEOUT = 0.1
WIN_PIPE_TIMEOUT = 0.1
DEFAULT_TIMEOUT = WIN_PIPE_TIMEOUT if sys.platform == 'win32' else UNIX_SOCKET_TIMEOUT
SOCK_TYPE = 'Windows named pipe' if sys.platform == 'win32' else 'UNIX domain socket'
def ssh_args(conn): def ssh_args(conn):
"""Create SSH command for connecting specified server.""" """Create SSH command for connecting specified server."""
@@ -35,7 +43,7 @@ def ssh_args(conn):
if 'user' in identity: if 'user' in identity:
args += ['-l', identity['user']] args += ['-l', identity['user']]
args += ['-o', 'IdentityFile={}'.format(pubkey_tempfile.name)] args += ['-o', 'IdentityFile={}'.format(pubkey_tempfile)]
args += ['-o', 'IdentitiesOnly=true'] args += ['-o', 'IdentitiesOnly=true']
return args + [identity['host']] return args + [identity['host']]
@@ -83,14 +91,14 @@ def create_agent_parser(device_type):
default=formats.CURVE_NIST256, default=formats.CURVE_NIST256,
help='specify ECDSA curve name: ' + curve_names) help='specify ECDSA curve name: ' + curve_names)
p.add_argument('--timeout', p.add_argument('--timeout',
default=UNIX_SOCKET_TIMEOUT, type=float, default=DEFAULT_TIMEOUT, type=float,
help='timeout for accepting SSH client connections') help='timeout for accepting SSH client connections')
p.add_argument('--debug', default=False, action='store_true', p.add_argument('--debug', default=False, action='store_true',
help='log SSH protocol messages for debugging.') help='log SSH protocol messages for debugging.')
p.add_argument('--log-file', type=str, p.add_argument('--log-file', type=str,
help='Path to the log file (to be written by the agent).') help='Path to the log file (to be written by the agent).')
p.add_argument('--sock-path', type=str, p.add_argument('--sock-path', type=str,
help='Path to the UNIX domain socket of the agent.') help='Path to the ' + SOCK_TYPE + ' of the agent.')
p.add_argument('--pin-entry-binary', type=str, default='pinentry', p.add_argument('--pin-entry-binary', type=str, default='pinentry',
help='Path to PIN entry UI helper.') help='Path to PIN entry UI helper.')
@@ -100,15 +108,18 @@ def create_agent_parser(device_type):
help='Expire passphrase from cache after this duration.') help='Expire passphrase from cache after this duration.')
g = p.add_mutually_exclusive_group() g = p.add_mutually_exclusive_group()
if daemon:
g.add_argument('-d', '--daemonize', default=False, action='store_true', g.add_argument('-d', '--daemonize', default=False, action='store_true',
help='Daemonize the agent and print its UNIX socket path') help='Daemonize the agent and print its ' + SOCK_TYPE)
g.add_argument('-f', '--foreground', default=False, action='store_true', g.add_argument('-f', '--foreground', default=False, action='store_true',
help='Run agent in foreground with specified UNIX socket path') help='Run agent in foreground with specified ' + SOCK_TYPE)
g.add_argument('-s', '--shell', default=False, action='store_true', g.add_argument('-s', '--shell', default=False, action='store_true',
help=('run ${SHELL} as subprocess under SSH agent, allowing ' help=('run ${SHELL} as subprocess under SSH agent, allowing '
'regular SSH-based tools to be used in the shell')) 'regular SSH-based tools to be used in the shell'))
g.add_argument('-c', '--connect', default=False, action='store_true', g.add_argument('-c', '--connect', default=False, action='store_true',
help='connect to specified host via SSH') help='connect to specified host via SSH')
# Windows doesn't have native mosh
if sys.platform != 'win32':
g.add_argument('--mosh', default=False, action='store_true', g.add_argument('--mosh', default=False, action='store_true',
help='connect to specified host via using Mosh') help='connect to specified host via using Mosh')
@@ -119,18 +130,48 @@ def create_agent_parser(device_type):
return p return p
def get_ssh_env(sock_path):
ssh_version = subprocess.check_output(['ssh', '-V'],
stderr=subprocess.STDOUT)
log.debug('local SSH version: %r', ssh_version)
return {'SSH_AUTH_SOCK': sock_path, 'SSH_AGENT_PID': str(os.getpid())}
# Windows doesn't support AF_UNIX yet
# https://bugs.python.org/issue33408
@contextlib.contextmanager @contextlib.contextmanager
def serve(handler, sock_path, timeout=UNIX_SOCKET_TIMEOUT): def serve_win(handler, sock_path, timeout=WIN_PIPE_TIMEOUT):
"""
Start the ssh-agent server on a Windows named pipe.
"""
environ = get_ssh_env(sock_path)
device_mutex = threading.Lock()
quit_event = threading.Event()
handle_conn = functools.partial(win_server.handle_connection,
handler=handler,
mutex=device_mutex,
quit_event=quit_event)
kwargs = dict(pipe_name=sock_path,
handle_conn=handle_conn,
quit_event=quit_event,
timeout=timeout)
with server.spawn(win_server.server_thread, kwargs):
try:
yield environ
finally:
log.debug('closing server')
quit_event.set()
@contextlib.contextmanager
def serve_unix(handler, sock_path, timeout=UNIX_SOCKET_TIMEOUT):
""" """
Start the ssh-agent server on a UNIX-domain socket. Start the ssh-agent server on a UNIX-domain socket.
If no connection is made during the specified timeout, If no connection is made during the specified timeout,
retry until the context is over. retry until the context is over.
""" """
ssh_version = subprocess.check_output(['ssh', '-V'], environ = get_ssh_env(sock_path)
stderr=subprocess.STDOUT)
log.debug('local SSH version: %r', ssh_version)
environ = {'SSH_AUTH_SOCK': sock_path, 'SSH_AGENT_PID': str(os.getpid())}
device_mutex = threading.Lock() device_mutex = threading.Lock()
with server.unix_domain_socket_server(sock_path) as sock: with server.unix_domain_socket_server(sock_path) as sock:
sock.settimeout(timeout) sock.settimeout(timeout)
@@ -154,12 +195,15 @@ def run_server(conn, command, sock_path, debug, timeout):
ret = 0 ret = 0
try: try:
handler = protocol.Handler(conn=conn, debug=debug) handler = protocol.Handler(conn=conn, debug=debug)
with serve(handler=handler, sock_path=sock_path, serve_platform = serve_win if sys.platform == 'win32' else serve_unix
timeout=timeout) as env: with serve_platform(handler=handler, sock_path=sock_path, timeout=timeout) as env:
if command: if command:
ret = server.run_process(command=command, environ=env) ret = server.run_process(command=command, environ=env)
else: else:
try:
signal.pause() # wait for signal (e.g. SIGINT) signal.pause() # wait for signal (e.g. SIGINT)
except AttributeError:
sys.stdin.read() # Windows doesn't support signal.pause
except KeyboardInterrupt: except KeyboardInterrupt:
log.info('server stopped') log.info('server stopped')
return ret return ret
@@ -221,10 +265,9 @@ class JustInTimeConnection:
"""Store public keys as temporary SSH identity files.""" """Store public keys as temporary SSH identity files."""
if not self.public_keys_tempfiles: if not self.public_keys_tempfiles:
for pk in self.public_keys(): for pk in self.public_keys():
f = tempfile.NamedTemporaryFile(prefix='trezor-ssh-pubkey-', mode='w') with tempfile.NamedTemporaryFile(prefix='trezor-ssh-pubkey-', mode='w', delete=False, newline='') as f:
f.write(pk) f.write(pk)
f.flush() self.public_keys_tempfiles.append(f.name)
self.public_keys_tempfiles.append(f)
return self.public_keys_tempfiles return self.public_keys_tempfiles
@@ -241,13 +284,16 @@ def _dummy_context():
def _get_sock_path(args): def _get_sock_path(args):
sock_path = args.sock_path sock_path = args.sock_path
if not sock_path: if sock_path:
if args.foreground:
log.error('running in foreground mode requires specifying UNIX socket path')
sys.exit(1)
else:
sock_path = tempfile.mktemp(prefix='trezor-ssh-agent-')
return sock_path return sock_path
elif args.foreground:
log.error('running in foreground mode requires specifying ' + SOCK_TYPE)
sys.exit(1)
elif sys.platform == 'win32':
suffix = random.choices(string.ascii_letters, k=10)
return '\\\\.\pipe\\trezor-ssh-agent-' + ''.join(suffix)
else:
return tempfile.mktemp(prefix='trezor-ssh-agent-')
@handle_connection_error @handle_connection_error
@@ -286,7 +332,7 @@ def main(device_type):
command = ['ssh'] + ssh_args(conn) + args.command command = ['ssh'] + ssh_args(conn) + args.command
elif args.mosh: elif args.mosh:
command = ['mosh'] + mosh_args(conn) + args.command command = ['mosh'] + mosh_args(conn) + args.command
elif args.daemonize: elif daemon and args.daemonize:
out = 'SSH_AUTH_SOCK={0}; export SSH_AUTH_SOCK;\n'.format(sock_path) out = 'SSH_AUTH_SOCK={0}; export SSH_AUTH_SOCK;\n'.format(sock_path)
sys.stdout.write(out) sys.stdout.write(out)
sys.stdout.flush() sys.stdout.flush()
@@ -300,7 +346,7 @@ def main(device_type):
command = os.environ['SHELL'] command = os.environ['SHELL']
sys.stdin.close() sys.stdin.close()
if command or args.daemonize or args.foreground: if command or (daemon and args.daemonize) or args.foreground:
with context: with context:
return run_server(conn=conn, command=command, sock_path=sock_path, return run_server(conn=conn, command=command, sock_path=sock_path,
debug=args.debug, timeout=args.timeout) debug=args.debug, timeout=args.timeout)

191
libagent/win_server.py Normal file
View File

@@ -0,0 +1,191 @@
"""Windows named pipe server for ssh-agent implementation."""
import logging
import pywintypes
import struct
import threading
import win32api
import win32event
import win32pipe
import win32file
import winerror
from . import util
log = logging.getLogger(__name__)
PIPE_BUFFER_SIZE = 64 * 1024
# Make MemoryView look like a buffer to reuse util.recv
class MvBuffer:
def __init__(self, mv):
self.mv = mv
def read(self, n):
return self.mv[0:n]
# Based loosely on https://docs.microsoft.com/en-us/windows/win32/ipc/multithreaded-pipe-server
class NamedPipe:
__frame_size_size = struct.calcsize('>L')
def __close(handle):
"""Closes a named pipe handle."""
if handle == win32file.INVALID_HANDLE_VALUE:
return
win32file.FlushFileBuffers(handle)
win32pipe.DisconnectNamedPipe(handle)
win32api.CloseHandle(handle)
def open(name):
"""Opens a named pipe server for receiving connections."""
handle = win32pipe.CreateNamedPipe(
name,
win32pipe.PIPE_ACCESS_DUPLEX | win32file.FILE_FLAG_OVERLAPPED,
win32pipe.PIPE_TYPE_MESSAGE | win32pipe.PIPE_READMODE_MESSAGE | win32pipe.PIPE_WAIT,
win32pipe.PIPE_UNLIMITED_INSTANCES,
PIPE_BUFFER_SIZE,
PIPE_BUFFER_SIZE,
0,
None) # Default security attributes
if handle == win32file.INVALID_HANDLE_VALUE:
log.error("CreateNamedPipe failed (%d)", win32api.GetLastError())
return None
try:
pending_io = False
overlapped = win32file.OVERLAPPED()
overlapped.hEvent = win32event.CreateEvent(None, True, True, None)
error_code = win32pipe.ConnectNamedPipe(handle, overlapped)
if error_code == winerror.ERROR_IO_PENDING:
pending_io = True
elif error_code != winerror.ERROR_PIPE_CONNECTED or not win32event.SetEvent(overlapped.hEvent):
log.error('ConnectNamedPipe failed (%d)', error_code)
return None
log.debug('waiting for connection on %s', name)
return NamedPipe(name, handle, overlapped, pending_io)
except:
NamedPipe.__close(handle)
raise
def __init__(self, name, handle, overlapped, pending_io):
self.name = name
self.handle = handle
self.overlapped = overlapped
self.pending_io = pending_io
def close(self):
"""Close the named pipe."""
NamedPipe.__close(self.handle)
def connect(self, timeout):
"""Connect to an SSH client with the specified timeout."""
waitHandle = win32event.WaitForSingleObject(
self.overlapped.hEvent,
timeout)
if waitHandle == win32event.WAIT_TIMEOUT:
return False
if not self.pending_io:
return True
win32pipe.GetOverlappedResult(
self.handle,
self.overlapped,
False)
error_code = win32api.GetLastError()
if error_code == winerror.NO_ERROR:
return True
log.error('GetOverlappedResult failed (%d)', error_code)
return False
def read_frame(self, quit_event):
"""Read the request frame from the SSH client."""
request_size = None
remaining = None
buf = MvBuffer(win32file.AllocateReadBuffer(PIPE_BUFFER_SIZE))
while True:
if quit_event.is_set():
return None
error_code, _ = win32file.ReadFile(self.handle, buf.mv, self.overlapped)
if error_code not in (winerror.NO_ERROR, winerror.ERROR_IO_PENDING, winerror.ERROR_MORE_DATA):
log.error('ReadFile failed (%d)', error_code)
return None
win32event.WaitForSingleObject(self.overlapped.hEvent, win32event.INFINITE)
chunk_size = win32pipe.GetOverlappedResult(self.handle, self.overlapped, False)
error_code = win32api.GetLastError()
if error_code != winerror.NO_ERROR:
log.error('GetOverlappedResult failed (%d)', error_code)
return None
if request_size:
remaining -= chunk_size
else:
request_size, = util.recv(buf, '>L')
remaining = request_size - (chunk_size - NamedPipe.__frame_size_size)
if remaining <= 0:
break
return util.recv(buf, request_size)
def send(self, reply):
"""Send the specified reply to the SSH client."""
error_code, _ = win32file.WriteFile(self.handle, reply)
if error_code == winerror.NO_ERROR:
return True
log.error('WriteFile failed (%d)', error_code)
return False
def handle_connection(pipe, handler, mutex, quit_event):
"""
Handle a single connection using the specified protocol handler in a loop.
Since this function may be called concurrently from server_thread,
the specified mutex is used to synchronize the device handling.
"""
log.debug('welcome agent')
try:
while True:
if quit_event.is_set():
return
msg = pipe.read_frame(quit_event)
if not msg:
return
with mutex:
reply = handler.handle(msg=msg)
if not pipe.send(reply):
return
except pywintypes.error as e:
# Surface errors that aren't related to the client disconnecting
if e.args[0] == winerror.ERROR_BROKEN_PIPE:
log.debug('goodbye agent')
else:
raise
except Exception as e: # pylint: disable=broad-except
log.warning('error: %s', e, exc_info=True)
finally:
pipe.close()
def server_thread(pipe_name, handle_conn, quit_event, timeout):
"""Run a Windows server on the specified pipe."""
log.debug('server thread started')
while True:
if quit_event.is_set():
break
# A new pipe instance is necessary for each client
pipe = NamedPipe.open(pipe_name)
if not pipe:
break
try:
# Poll for a new client connection
while True:
if quit_event.is_set():
break
if pipe.connect(timeout * 1000):
# Handle connections from SSH concurrently.
threading.Thread(target=handle_conn,
kwargs=dict(pipe=pipe)).start()
break
except:
pipe.close()
raise
log.debug('server thread stopped')

View File

@@ -27,6 +27,7 @@ setup(
'pymsgbox>=1.0.6', 'pymsgbox>=1.0.6',
'semver>=2.2', 'semver>=2.2',
'unidecode>=0.4.20', 'unidecode>=0.4.20',
'pypiwin32'
], ],
platforms=['POSIX'], platforms=['POSIX'],
classifiers=[ classifiers=[

View File

@@ -14,6 +14,7 @@ deps=
semver semver
pydocstyle pydocstyle
isort<5 isort<5
pypiwin32
commands= commands=
pycodestyle libagent pycodestyle libagent
isort --skip-glob .tox -c -rc libagent isort --skip-glob .tox -c -rc libagent