From 2eff21f96ccc22b4d36fd0903c3d8bd83f12307c Mon Sep 17 00:00:00 2001 From: Roman Zeyde Date: Tue, 19 Jan 2016 22:52:52 +0200 Subject: [PATCH] factory: refactor for easier testing --- trezor_agent/tests/test_factory.py | 92 ++++++++++++++++++++++++++++++ trezor_agent/trezor/factory.py | 19 ++++-- 2 files changed, 106 insertions(+), 5 deletions(-) create mode 100644 trezor_agent/tests/test_factory.py diff --git a/trezor_agent/tests/test_factory.py b/trezor_agent/tests/test_factory.py new file mode 100644 index 0000000..5371d41 --- /dev/null +++ b/trezor_agent/tests/test_factory.py @@ -0,0 +1,92 @@ +import mock +import pytest + +from ..trezor import factory + + +def test_load(): + + def single(): + return [0] + + def nothing(): + return [] + + def double(): + return [1, 2] + + assert factory.load(loaders=[single]) == 0 + assert factory.load(loaders=[single, nothing]) == 0 + assert factory.load(loaders=[nothing, single]) == 0 + + with pytest.raises(IOError): + factory.load(loaders=[]) + + with pytest.raises(IOError): + factory.load(loaders=[single, single]) + + with pytest.raises(IOError): + factory.load(loaders=[double]) + + +factory_load_client = factory._load_client # pylint: disable=protected-access + + +def test_load_nothing(): + hid_transport = mock.Mock() + hid_transport.enumerate.return_value = [] + result = factory_load_client( + name=None, + client_type=None, + hid_transport=hid_transport, + passphrase_ack=None, + identity_type=None, + required_version=None) + assert list(result) == [] + + +def create_client_type(version): + conn = mock.Mock() + conn.features = mock.Mock() + major, minor, patch = version.split('.') + conn.features.major_version = major + conn.features.minor_version = minor + conn.features.patch_version = patch + conn.features.revision = b'\x12\x34\x56\x78' + client_type = mock.Mock() + client_type.return_value = conn + return client_type + + +def test_load_single(): + hid_transport = mock.Mock() + hid_transport.enumerate.return_value = [0] + for version in ('1.3.4', '1.3.5', '1.4.0', '2.0.0'): + passphrase_ack = mock.Mock() + client_type = create_client_type(version) + result = factory_load_client( + name='DEVICE_NAME', + client_type=client_type, + hid_transport=hid_transport, + passphrase_ack=passphrase_ack, + identity_type=None, + required_version='>=1.3.4') + client_wrapper, = result + assert client_wrapper.connection is client_type.return_value + assert client_wrapper.device_name == 'DEVICE_NAME' + client_wrapper.connection.callback_PassphraseRequest('MESSAGE') + assert passphrase_ack.mock_calls == [mock.call(passphrase='')] + + +def test_load_old(): + hid_transport = mock.Mock() + hid_transport.enumerate.return_value = [0] + for version in ('1.3.3', '1.2.5', '1.1.0', '0.9.9'): + with pytest.raises(ValueError): + next(factory_load_client( + name='DEVICE_NAME', + client_type=create_client_type(version), + hid_transport=hid_transport, + passphrase_ack=None, + identity_type=None, + required_version='>=1.3.4')) diff --git a/trezor_agent/trezor/factory.py b/trezor_agent/trezor/factory.py index 4b6259b..5a263fa 100644 --- a/trezor_agent/trezor/factory.py +++ b/trezor_agent/trezor/factory.py @@ -68,11 +68,20 @@ def _load_keepkey(): identity_type=IdentityType, required_version='>=1.0.4') +LOADERS = [ + _load_trezor, + _load_keepkey +] -def load(): - devices = list(_load_trezor()) + list(_load_keepkey()) - if len(devices) == 1: - return devices[0] - msg = '{:d} devices found'.format(len(devices)) +def load(loaders=None): + loaders = loaders if loaders is not None else LOADERS + device_list = [] + for loader in loaders: + device_list.extend(loader()) + + if len(device_list) == 1: + return device_list[0] + + msg = '{:d} devices found'.format(len(device_list)) raise IOError(msg)