Add Windows native SSH support
This commit is contained in:
@@ -78,7 +78,8 @@ class UI:
|
||||
def create_default_options_getter():
|
||||
"""Return current TTY and DISPLAY settings for GnuPG pinentry."""
|
||||
options = []
|
||||
if sys.stdin.isatty(): # short-circuit calling `tty`
|
||||
# Windows reports that it has a TTY but throws FileNotFoundError
|
||||
if sys.platform != 'win32' and sys.stdin.isatty(): # short-circuit calling `tty`
|
||||
try:
|
||||
ttyname = subprocess.check_output(args=['tty']).strip()
|
||||
options.append(b'ttyname=' + ttyname)
|
||||
@@ -88,7 +89,8 @@ def create_default_options_getter():
|
||||
display = os.environ.get('DISPLAY')
|
||||
if display is not None:
|
||||
options.append('display={}'.format(display).encode('ascii'))
|
||||
else:
|
||||
# Windows likely doesn't support this anyway
|
||||
elif sys.platform != 'win32':
|
||||
log.warning('DISPLAY not defined')
|
||||
|
||||
log.info('using %s for pinentry options', options)
|
||||
|
||||
@@ -4,24 +4,32 @@ import functools
|
||||
import io
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
import signal
|
||||
import string
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
import threading
|
||||
|
||||
import configargparse
|
||||
import daemon
|
||||
try:
|
||||
# TODO: Not supported on Windows. Use daemoniker instead?
|
||||
import daemon
|
||||
except ImportError:
|
||||
daemon = None
|
||||
import pkg_resources
|
||||
|
||||
from .. import device, formats, server, util
|
||||
from .. import device, formats, server, util, win_server
|
||||
from . import client, protocol
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
UNIX_SOCKET_TIMEOUT = 0.1
|
||||
|
||||
WIN_PIPE_TIMEOUT = 0.1
|
||||
DEFAULT_TIMEOUT = WIN_PIPE_TIMEOUT if sys.platform == 'win32' else UNIX_SOCKET_TIMEOUT
|
||||
SOCK_TYPE = 'Windows named pipe' if sys.platform == 'win32' else 'UNIX domain socket'
|
||||
|
||||
def ssh_args(conn):
|
||||
"""Create SSH command for connecting specified server."""
|
||||
@@ -35,7 +43,7 @@ def ssh_args(conn):
|
||||
if 'user' in identity:
|
||||
args += ['-l', identity['user']]
|
||||
|
||||
args += ['-o', 'IdentityFile={}'.format(pubkey_tempfile.name)]
|
||||
args += ['-o', 'IdentityFile={}'.format(pubkey_tempfile)]
|
||||
args += ['-o', 'IdentitiesOnly=true']
|
||||
return args + [identity['host']]
|
||||
|
||||
@@ -83,14 +91,14 @@ def create_agent_parser(device_type):
|
||||
default=formats.CURVE_NIST256,
|
||||
help='specify ECDSA curve name: ' + curve_names)
|
||||
p.add_argument('--timeout',
|
||||
default=UNIX_SOCKET_TIMEOUT, type=float,
|
||||
default=DEFAULT_TIMEOUT, type=float,
|
||||
help='timeout for accepting SSH client connections')
|
||||
p.add_argument('--debug', default=False, action='store_true',
|
||||
help='log SSH protocol messages for debugging.')
|
||||
p.add_argument('--log-file', type=str,
|
||||
help='Path to the log file (to be written by the agent).')
|
||||
p.add_argument('--sock-path', type=str,
|
||||
help='Path to the UNIX domain socket of the agent.')
|
||||
help='Path to the ' + SOCK_TYPE + ' of the agent.')
|
||||
|
||||
p.add_argument('--pin-entry-binary', type=str, default='pinentry',
|
||||
help='Path to PIN entry UI helper.')
|
||||
@@ -100,15 +108,18 @@ def create_agent_parser(device_type):
|
||||
help='Expire passphrase from cache after this duration.')
|
||||
|
||||
g = p.add_mutually_exclusive_group()
|
||||
if daemon:
|
||||
g.add_argument('-d', '--daemonize', default=False, action='store_true',
|
||||
help='Daemonize the agent and print its UNIX socket path')
|
||||
help='Daemonize the agent and print its ' + SOCK_TYPE)
|
||||
g.add_argument('-f', '--foreground', default=False, action='store_true',
|
||||
help='Run agent in foreground with specified UNIX socket path')
|
||||
help='Run agent in foreground with specified ' + SOCK_TYPE)
|
||||
g.add_argument('-s', '--shell', default=False, action='store_true',
|
||||
help=('run ${SHELL} as subprocess under SSH agent, allowing '
|
||||
'regular SSH-based tools to be used in the shell'))
|
||||
g.add_argument('-c', '--connect', default=False, action='store_true',
|
||||
help='connect to specified host via SSH')
|
||||
# Windows doesn't have native mosh
|
||||
if sys.platform != 'win32':
|
||||
g.add_argument('--mosh', default=False, action='store_true',
|
||||
help='connect to specified host via using Mosh')
|
||||
|
||||
@@ -119,18 +130,48 @@ def create_agent_parser(device_type):
|
||||
return p
|
||||
|
||||
|
||||
def get_ssh_env(sock_path):
|
||||
ssh_version = subprocess.check_output(['ssh', '-V'],
|
||||
stderr=subprocess.STDOUT)
|
||||
log.debug('local SSH version: %r', ssh_version)
|
||||
return {'SSH_AUTH_SOCK': sock_path, 'SSH_AGENT_PID': str(os.getpid())}
|
||||
|
||||
|
||||
# Windows doesn't support AF_UNIX yet
|
||||
# https://bugs.python.org/issue33408
|
||||
@contextlib.contextmanager
|
||||
def serve(handler, sock_path, timeout=UNIX_SOCKET_TIMEOUT):
|
||||
def serve_win(handler, sock_path, timeout=WIN_PIPE_TIMEOUT):
|
||||
"""
|
||||
Start the ssh-agent server on a Windows named pipe.
|
||||
"""
|
||||
environ = get_ssh_env(sock_path)
|
||||
device_mutex = threading.Lock()
|
||||
quit_event = threading.Event()
|
||||
handle_conn = functools.partial(win_server.handle_connection,
|
||||
handler=handler,
|
||||
mutex=device_mutex,
|
||||
quit_event=quit_event)
|
||||
kwargs = dict(pipe_name=sock_path,
|
||||
handle_conn=handle_conn,
|
||||
quit_event=quit_event,
|
||||
timeout=timeout)
|
||||
with server.spawn(win_server.server_thread, kwargs):
|
||||
try:
|
||||
yield environ
|
||||
finally:
|
||||
log.debug('closing server')
|
||||
quit_event.set()
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def serve_unix(handler, sock_path, timeout=UNIX_SOCKET_TIMEOUT):
|
||||
"""
|
||||
Start the ssh-agent server on a UNIX-domain socket.
|
||||
|
||||
If no connection is made during the specified timeout,
|
||||
retry until the context is over.
|
||||
"""
|
||||
ssh_version = subprocess.check_output(['ssh', '-V'],
|
||||
stderr=subprocess.STDOUT)
|
||||
log.debug('local SSH version: %r', ssh_version)
|
||||
environ = {'SSH_AUTH_SOCK': sock_path, 'SSH_AGENT_PID': str(os.getpid())}
|
||||
environ = get_ssh_env(sock_path)
|
||||
device_mutex = threading.Lock()
|
||||
with server.unix_domain_socket_server(sock_path) as sock:
|
||||
sock.settimeout(timeout)
|
||||
@@ -154,12 +195,15 @@ def run_server(conn, command, sock_path, debug, timeout):
|
||||
ret = 0
|
||||
try:
|
||||
handler = protocol.Handler(conn=conn, debug=debug)
|
||||
with serve(handler=handler, sock_path=sock_path,
|
||||
timeout=timeout) as env:
|
||||
serve_platform = serve_win if sys.platform == 'win32' else serve_unix
|
||||
with serve_platform(handler=handler, sock_path=sock_path, timeout=timeout) as env:
|
||||
if command:
|
||||
ret = server.run_process(command=command, environ=env)
|
||||
else:
|
||||
try:
|
||||
signal.pause() # wait for signal (e.g. SIGINT)
|
||||
except AttributeError:
|
||||
sys.stdin.read() # Windows doesn't support signal.pause
|
||||
except KeyboardInterrupt:
|
||||
log.info('server stopped')
|
||||
return ret
|
||||
@@ -221,10 +265,9 @@ class JustInTimeConnection:
|
||||
"""Store public keys as temporary SSH identity files."""
|
||||
if not self.public_keys_tempfiles:
|
||||
for pk in self.public_keys():
|
||||
f = tempfile.NamedTemporaryFile(prefix='trezor-ssh-pubkey-', mode='w')
|
||||
with tempfile.NamedTemporaryFile(prefix='trezor-ssh-pubkey-', mode='w', delete=False, newline='') as f:
|
||||
f.write(pk)
|
||||
f.flush()
|
||||
self.public_keys_tempfiles.append(f)
|
||||
self.public_keys_tempfiles.append(f.name)
|
||||
|
||||
return self.public_keys_tempfiles
|
||||
|
||||
@@ -241,13 +284,16 @@ def _dummy_context():
|
||||
|
||||
def _get_sock_path(args):
|
||||
sock_path = args.sock_path
|
||||
if not sock_path:
|
||||
if args.foreground:
|
||||
log.error('running in foreground mode requires specifying UNIX socket path')
|
||||
sys.exit(1)
|
||||
else:
|
||||
sock_path = tempfile.mktemp(prefix='trezor-ssh-agent-')
|
||||
if sock_path:
|
||||
return sock_path
|
||||
elif args.foreground:
|
||||
log.error('running in foreground mode requires specifying ' + SOCK_TYPE)
|
||||
sys.exit(1)
|
||||
elif sys.platform == 'win32':
|
||||
suffix = random.choices(string.ascii_letters, k=10)
|
||||
return '\\\\.\pipe\\trezor-ssh-agent-' + ''.join(suffix)
|
||||
else:
|
||||
return tempfile.mktemp(prefix='trezor-ssh-agent-')
|
||||
|
||||
|
||||
@handle_connection_error
|
||||
@@ -286,7 +332,7 @@ def main(device_type):
|
||||
command = ['ssh'] + ssh_args(conn) + args.command
|
||||
elif args.mosh:
|
||||
command = ['mosh'] + mosh_args(conn) + args.command
|
||||
elif args.daemonize:
|
||||
elif daemon and args.daemonize:
|
||||
out = 'SSH_AUTH_SOCK={0}; export SSH_AUTH_SOCK;\n'.format(sock_path)
|
||||
sys.stdout.write(out)
|
||||
sys.stdout.flush()
|
||||
@@ -300,7 +346,7 @@ def main(device_type):
|
||||
command = os.environ['SHELL']
|
||||
sys.stdin.close()
|
||||
|
||||
if command or args.daemonize or args.foreground:
|
||||
if command or (daemon and args.daemonize) or args.foreground:
|
||||
with context:
|
||||
return run_server(conn=conn, command=command, sock_path=sock_path,
|
||||
debug=args.debug, timeout=args.timeout)
|
||||
|
||||
191
libagent/win_server.py
Normal file
191
libagent/win_server.py
Normal file
@@ -0,0 +1,191 @@
|
||||
"""Windows named pipe server for ssh-agent implementation."""
|
||||
import logging
|
||||
import pywintypes
|
||||
import struct
|
||||
import threading
|
||||
import win32api
|
||||
import win32event
|
||||
import win32pipe
|
||||
import win32file
|
||||
import winerror
|
||||
|
||||
from . import util
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
PIPE_BUFFER_SIZE = 64 * 1024
|
||||
|
||||
# Make MemoryView look like a buffer to reuse util.recv
|
||||
class MvBuffer:
|
||||
def __init__(self, mv):
|
||||
self.mv = mv
|
||||
def read(self, n):
|
||||
return self.mv[0:n]
|
||||
|
||||
# Based loosely on https://docs.microsoft.com/en-us/windows/win32/ipc/multithreaded-pipe-server
|
||||
class NamedPipe:
|
||||
__frame_size_size = struct.calcsize('>L')
|
||||
|
||||
def __close(handle):
|
||||
"""Closes a named pipe handle."""
|
||||
if handle == win32file.INVALID_HANDLE_VALUE:
|
||||
return
|
||||
win32file.FlushFileBuffers(handle)
|
||||
win32pipe.DisconnectNamedPipe(handle)
|
||||
win32api.CloseHandle(handle)
|
||||
|
||||
def open(name):
|
||||
"""Opens a named pipe server for receiving connections."""
|
||||
handle = win32pipe.CreateNamedPipe(
|
||||
name,
|
||||
win32pipe.PIPE_ACCESS_DUPLEX | win32file.FILE_FLAG_OVERLAPPED,
|
||||
win32pipe.PIPE_TYPE_MESSAGE | win32pipe.PIPE_READMODE_MESSAGE | win32pipe.PIPE_WAIT,
|
||||
win32pipe.PIPE_UNLIMITED_INSTANCES,
|
||||
PIPE_BUFFER_SIZE,
|
||||
PIPE_BUFFER_SIZE,
|
||||
0,
|
||||
None) # Default security attributes
|
||||
|
||||
if handle == win32file.INVALID_HANDLE_VALUE:
|
||||
log.error("CreateNamedPipe failed (%d)", win32api.GetLastError())
|
||||
return None
|
||||
|
||||
try:
|
||||
pending_io = False
|
||||
overlapped = win32file.OVERLAPPED()
|
||||
overlapped.hEvent = win32event.CreateEvent(None, True, True, None)
|
||||
error_code = win32pipe.ConnectNamedPipe(handle, overlapped)
|
||||
if error_code == winerror.ERROR_IO_PENDING:
|
||||
pending_io = True
|
||||
elif error_code != winerror.ERROR_PIPE_CONNECTED or not win32event.SetEvent(overlapped.hEvent):
|
||||
log.error('ConnectNamedPipe failed (%d)', error_code)
|
||||
return None
|
||||
log.debug('waiting for connection on %s', name)
|
||||
return NamedPipe(name, handle, overlapped, pending_io)
|
||||
except:
|
||||
NamedPipe.__close(handle)
|
||||
raise
|
||||
|
||||
def __init__(self, name, handle, overlapped, pending_io):
|
||||
self.name = name
|
||||
self.handle = handle
|
||||
self.overlapped = overlapped
|
||||
self.pending_io = pending_io
|
||||
|
||||
def close(self):
|
||||
"""Close the named pipe."""
|
||||
NamedPipe.__close(self.handle)
|
||||
|
||||
def connect(self, timeout):
|
||||
"""Connect to an SSH client with the specified timeout."""
|
||||
waitHandle = win32event.WaitForSingleObject(
|
||||
self.overlapped.hEvent,
|
||||
timeout)
|
||||
if waitHandle == win32event.WAIT_TIMEOUT:
|
||||
return False
|
||||
if not self.pending_io:
|
||||
return True
|
||||
win32pipe.GetOverlappedResult(
|
||||
self.handle,
|
||||
self.overlapped,
|
||||
False)
|
||||
error_code = win32api.GetLastError()
|
||||
if error_code == winerror.NO_ERROR:
|
||||
return True
|
||||
log.error('GetOverlappedResult failed (%d)', error_code)
|
||||
return False
|
||||
|
||||
def read_frame(self, quit_event):
|
||||
"""Read the request frame from the SSH client."""
|
||||
request_size = None
|
||||
remaining = None
|
||||
buf = MvBuffer(win32file.AllocateReadBuffer(PIPE_BUFFER_SIZE))
|
||||
while True:
|
||||
if quit_event.is_set():
|
||||
return None
|
||||
error_code, _ = win32file.ReadFile(self.handle, buf.mv, self.overlapped)
|
||||
if error_code not in (winerror.NO_ERROR, winerror.ERROR_IO_PENDING, winerror.ERROR_MORE_DATA):
|
||||
log.error('ReadFile failed (%d)', error_code)
|
||||
return None
|
||||
win32event.WaitForSingleObject(self.overlapped.hEvent, win32event.INFINITE)
|
||||
chunk_size = win32pipe.GetOverlappedResult(self.handle, self.overlapped, False)
|
||||
error_code = win32api.GetLastError()
|
||||
if error_code != winerror.NO_ERROR:
|
||||
log.error('GetOverlappedResult failed (%d)', error_code)
|
||||
return None
|
||||
if request_size:
|
||||
remaining -= chunk_size
|
||||
else:
|
||||
request_size, = util.recv(buf, '>L')
|
||||
remaining = request_size - (chunk_size - NamedPipe.__frame_size_size)
|
||||
if remaining <= 0:
|
||||
break
|
||||
return util.recv(buf, request_size)
|
||||
|
||||
def send(self, reply):
|
||||
"""Send the specified reply to the SSH client."""
|
||||
error_code, _ = win32file.WriteFile(self.handle, reply)
|
||||
if error_code == winerror.NO_ERROR:
|
||||
return True
|
||||
log.error('WriteFile failed (%d)', error_code)
|
||||
return False
|
||||
|
||||
|
||||
def handle_connection(pipe, handler, mutex, quit_event):
|
||||
"""
|
||||
Handle a single connection using the specified protocol handler in a loop.
|
||||
|
||||
Since this function may be called concurrently from server_thread,
|
||||
the specified mutex is used to synchronize the device handling.
|
||||
"""
|
||||
log.debug('welcome agent')
|
||||
|
||||
try:
|
||||
while True:
|
||||
if quit_event.is_set():
|
||||
return
|
||||
msg = pipe.read_frame(quit_event)
|
||||
if not msg:
|
||||
return
|
||||
with mutex:
|
||||
reply = handler.handle(msg=msg)
|
||||
if not pipe.send(reply):
|
||||
return
|
||||
except pywintypes.error as e:
|
||||
# Surface errors that aren't related to the client disconnecting
|
||||
if e.args[0] == winerror.ERROR_BROKEN_PIPE:
|
||||
log.debug('goodbye agent')
|
||||
else:
|
||||
raise
|
||||
except Exception as e: # pylint: disable=broad-except
|
||||
log.warning('error: %s', e, exc_info=True)
|
||||
finally:
|
||||
pipe.close()
|
||||
|
||||
|
||||
def server_thread(pipe_name, handle_conn, quit_event, timeout):
|
||||
"""Run a Windows server on the specified pipe."""
|
||||
log.debug('server thread started')
|
||||
|
||||
while True:
|
||||
if quit_event.is_set():
|
||||
break
|
||||
# A new pipe instance is necessary for each client
|
||||
pipe = NamedPipe.open(pipe_name)
|
||||
if not pipe:
|
||||
break
|
||||
try:
|
||||
# Poll for a new client connection
|
||||
while True:
|
||||
if quit_event.is_set():
|
||||
break
|
||||
if pipe.connect(timeout * 1000):
|
||||
# Handle connections from SSH concurrently.
|
||||
threading.Thread(target=handle_conn,
|
||||
kwargs=dict(pipe=pipe)).start()
|
||||
break
|
||||
except:
|
||||
pipe.close()
|
||||
raise
|
||||
|
||||
log.debug('server thread stopped')
|
||||
1
setup.py
1
setup.py
@@ -27,6 +27,7 @@ setup(
|
||||
'pymsgbox>=1.0.6',
|
||||
'semver>=2.2',
|
||||
'unidecode>=0.4.20',
|
||||
'pypiwin32'
|
||||
],
|
||||
platforms=['POSIX'],
|
||||
classifiers=[
|
||||
|
||||
Reference in New Issue
Block a user