diff --git a/trezor_agent/tests/test_util.py b/trezor_agent/tests/test_util.py index e417e5d..5a1310b 100644 --- a/trezor_agent/tests/test_util.py +++ b/trezor_agent/tests/test_util.py @@ -1,5 +1,6 @@ import io +import mock import pytest from .. import util @@ -101,3 +102,16 @@ def test_reader(): def test_setup_logging(): util.setup_logging(verbosity=10) + + +def test_memoize(): + f = mock.Mock(side_effect=lambda x: x) + + def func(x): + # mock.Mock doesn't work with functools.wraps() + return f(x) + + g = util.memoize(func) + assert g(1) == g(1) + assert g(1) != g(2) + assert f.mock_calls == [mock.call(1), mock.call(2)] diff --git a/trezor_agent/util.py b/trezor_agent/util.py index c6a43c8..6c397aa 100644 --- a/trezor_agent/util.py +++ b/trezor_agent/util.py @@ -1,6 +1,7 @@ """Various I/O and serialization utilities.""" import binascii import contextlib +import functools import io import logging import struct @@ -185,3 +186,21 @@ def setup_logging(verbosity, **kwargs): levels = [logging.WARNING, logging.INFO, logging.DEBUG] level = levels[min(verbosity, len(levels) - 1)] logging.basicConfig(format=fmt, level=level, **kwargs) + + +def memoize(func): + """Simple caching decorator.""" + cache = {} + + @functools.wraps(func) + def wrapper(*args, **kwargs): + """Caching wrapper.""" + key = (args, tuple(sorted(kwargs.items()))) + if key in cache: + return cache[key] + else: + result = func(*args, **kwargs) + cache[key] = result + return result + + return wrapper