trezor: allow expiring cached passphrase
This commit is contained in:
@@ -7,6 +7,7 @@ import mnemonic
|
||||
import semver
|
||||
|
||||
from . import interface
|
||||
from .. import util
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
@@ -46,7 +47,8 @@ class Trezor(interface.Device):
|
||||
|
||||
conn.callback_PinMatrixRequest = new_handler
|
||||
|
||||
cached_passphrase_ack = None
|
||||
# Remembers the passphrase for an hour.
|
||||
cached_passphrase_ack = util.ExpiringCache(seconds=60*60)
|
||||
cached_state = None
|
||||
|
||||
def _override_passphrase_handler(self, conn):
|
||||
@@ -57,9 +59,10 @@ class Trezor(interface.Device):
|
||||
try:
|
||||
if msg.on_device is True:
|
||||
return self._defs.PassphraseAck()
|
||||
if self.__class__.cached_passphrase_ack:
|
||||
ack = self.__class__.cached_passphrase_ack.get()
|
||||
if ack:
|
||||
log.debug('re-using cached %s passphrase', self)
|
||||
return self.__class__.cached_passphrase_ack
|
||||
return ack
|
||||
|
||||
passphrase = self.ui.get_passphrase()
|
||||
passphrase = mnemonic.Mnemonic.normalize_string(passphrase)
|
||||
@@ -70,7 +73,7 @@ class Trezor(interface.Device):
|
||||
msg = 'Too long passphrase ({} chars)'.format(length)
|
||||
raise ValueError(msg)
|
||||
|
||||
self.__class__.cached_passphrase_ack = ack
|
||||
self.__class__.cached_passphrase_ack.set(ack)
|
||||
return ack
|
||||
except: # noqa
|
||||
conn.init_device()
|
||||
|
||||
@@ -121,3 +121,26 @@ def test_assuan_serialize():
|
||||
assert util.assuan_serialize(b'') == b''
|
||||
assert util.assuan_serialize(b'123\n456') == b'123%0A456'
|
||||
assert util.assuan_serialize(b'\r\n') == b'%0D%0A'
|
||||
|
||||
|
||||
def test_cache():
|
||||
timer = mock.Mock(side_effect=range(7))
|
||||
c = util.ExpiringCache(seconds=2, timer=timer) # t=0
|
||||
assert c.get() is None # t=1
|
||||
obj = 'foo'
|
||||
c.set(obj) # t=2
|
||||
assert c.get() is obj # t=3
|
||||
assert c.get() is obj # t=4
|
||||
assert c.get() is None # t=5
|
||||
assert c.get() is None # t=6
|
||||
|
||||
|
||||
def test_cache_inf():
|
||||
timer = mock.Mock(side_effect=range(6))
|
||||
c = util.ExpiringCache(seconds=float('inf'), timer=timer)
|
||||
obj = 'foo'
|
||||
c.set(obj)
|
||||
assert c.get() is obj
|
||||
assert c.get() is obj
|
||||
assert c.get() is obj
|
||||
assert c.get() is obj
|
||||
|
||||
@@ -5,6 +5,7 @@ import functools
|
||||
import io
|
||||
import logging
|
||||
import struct
|
||||
import time
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
@@ -255,3 +256,25 @@ def assuan_serialize(data):
|
||||
escaped = '%{:02X}'.format(ord(c)).encode('ascii')
|
||||
data = data.replace(c, escaped)
|
||||
return data
|
||||
|
||||
|
||||
class ExpiringCache(object):
|
||||
"""Simple cache with a deadline."""
|
||||
|
||||
def __init__(self, seconds, timer=time.time):
|
||||
"""C-tor."""
|
||||
self.duration = seconds
|
||||
self.timer = timer
|
||||
self.value = None
|
||||
self.set(None)
|
||||
|
||||
def get(self):
|
||||
"""Returns existing value, or None if deadline has expired."""
|
||||
if self.timer() > self.deadline:
|
||||
self.value = None
|
||||
return self.value
|
||||
|
||||
def set(self, value):
|
||||
"""Set new value and reset the deadline for expiration."""
|
||||
self.deadline = self.timer() + self.duration
|
||||
self.value = value
|
||||
|
||||
Reference in New Issue
Block a user