formats: verify public key according to requested ECDSA curve
This commit is contained in:
@@ -79,26 +79,43 @@ def parse_pubkey(blob):
|
||||
return result
|
||||
|
||||
|
||||
def decompress_pubkey(pub):
|
||||
if pub[:1] == b'\x00':
|
||||
def _decompress_ed25519(pubkey):
|
||||
if pubkey[:1] == b'\x00':
|
||||
# set by Trezor fsm_msgSignIdentity() and fsm_msgGetPublicKey()
|
||||
return ed25519.VerifyingKey(pub[1:])
|
||||
return ed25519.VerifyingKey(pubkey[1:])
|
||||
|
||||
if pub[:1] in {b'\x02', b'\x03'}: # set by ecdsa_get_public_key33()
|
||||
|
||||
def _decompress_nist256(pubkey):
|
||||
if pubkey[:1] in {b'\x02', b'\x03'}: # set by ecdsa_get_public_key33()
|
||||
curve = ecdsa.NIST256p
|
||||
P = curve.curve.p()
|
||||
A = curve.curve.a()
|
||||
B = curve.curve.b()
|
||||
x = util.bytes2num(pub[1:33])
|
||||
x = util.bytes2num(pubkey[1:33])
|
||||
beta = pow(int(x * x * x + A * x + B), int((P + 1) // 4), int(P))
|
||||
|
||||
p0 = util.bytes2num(pub[:1])
|
||||
p0 = util.bytes2num(pubkey[:1])
|
||||
y = (P - beta) if ((beta + p0) % 2) else beta
|
||||
|
||||
point = ecdsa.ellipticcurve.Point(curve.curve, x, y)
|
||||
return ecdsa.VerifyingKey.from_public_point(point, curve=curve,
|
||||
hashfunc=hashfunc)
|
||||
raise ValueError('invalid {!r}', pub)
|
||||
|
||||
|
||||
def decompress_pubkey(pubkey, curve_name):
|
||||
vk = None
|
||||
if len(pubkey) == 33:
|
||||
decompress = {
|
||||
CURVE_NIST256: _decompress_nist256,
|
||||
CURVE_ED25519: _decompress_ed25519
|
||||
}[curve_name]
|
||||
vk = decompress(pubkey)
|
||||
|
||||
if not vk:
|
||||
msg = 'invalid {!s} public key: {!r}'.format(curve_name, pubkey)
|
||||
raise ValueError(msg)
|
||||
|
||||
return vk
|
||||
|
||||
|
||||
def serialize_verifying_key(vk):
|
||||
@@ -119,10 +136,8 @@ def serialize_verifying_key(vk):
|
||||
raise TypeError('unsupported {!r}'.format(vk))
|
||||
|
||||
|
||||
def export_public_key(pubkey, label):
|
||||
assert len(pubkey) == 33
|
||||
key_type, blob = serialize_verifying_key(decompress_pubkey(pubkey))
|
||||
|
||||
def export_public_key(vk, label):
|
||||
key_type, blob = serialize_verifying_key(vk)
|
||||
log.debug('fingerprint: %s', fingerprint(blob))
|
||||
b64 = base64.b64encode(blob).decode('ascii')
|
||||
return '{} {} {}\n'.format(key_type.decode('ascii'), b64, label)
|
||||
|
||||
@@ -35,8 +35,9 @@ def test_parse_public_key():
|
||||
|
||||
def test_decompress():
|
||||
blob = '036236ceabde25207e81e404586e3a3af1acda1dfed2abbbb4876c1fc5b296b575'
|
||||
result = formats.export_public_key(binascii.unhexlify(blob), label='home')
|
||||
assert result == _public_key
|
||||
vk = formats.decompress_pubkey(binascii.unhexlify(blob),
|
||||
curve_name=formats.CURVE_NIST256)
|
||||
assert formats.export_public_key(vk, label='home') == _public_key
|
||||
|
||||
|
||||
def test_parse_ed25519():
|
||||
@@ -57,7 +58,7 @@ def test_parse_ed25519():
|
||||
def test_export_ed25519():
|
||||
pub = (b'\x00P]\x17kc}#\xbc\x9c\xb2"\xef~\xa2\xb3\xe7\xf4'
|
||||
b'z\xba\xb6\xf1\x14\xdc\xec)\x0c\xd7SY\xb52\x91')
|
||||
vk = formats.decompress_pubkey(pub)
|
||||
vk = formats.decompress_pubkey(pub, formats.CURVE_ED25519)
|
||||
result = formats.serialize_verifying_key(vk)
|
||||
assert result == (b'ssh-ed25519',
|
||||
b'\x00\x00\x00\x0bssh-ed25519\x00\x00\x00 P]\x17kc}#\xbc'
|
||||
@@ -67,7 +68,25 @@ def test_export_ed25519():
|
||||
|
||||
def test_decompress_error():
|
||||
with pytest.raises(ValueError):
|
||||
formats.decompress_pubkey('')
|
||||
formats.decompress_pubkey('', formats.CURVE_NIST256)
|
||||
|
||||
|
||||
def test_curve_mismatch():
|
||||
# NIST256 public key
|
||||
blob = '036236ceabde25207e81e404586e3a3af1acda1dfed2abbbb4876c1fc5b296b575'
|
||||
with pytest.raises(ValueError):
|
||||
formats.decompress_pubkey(binascii.unhexlify(blob),
|
||||
curve_name=formats.CURVE_ED25519)
|
||||
|
||||
blob = '00' * 33 # Dummy public key
|
||||
with pytest.raises(ValueError):
|
||||
formats.decompress_pubkey(binascii.unhexlify(blob),
|
||||
curve_name=formats.CURVE_NIST256)
|
||||
|
||||
blob = 'FF' * 33 # Unsupported prefix byte
|
||||
with pytest.raises(ValueError):
|
||||
formats.decompress_pubkey(binascii.unhexlify(blob),
|
||||
curve_name=formats.CURVE_NIST256)
|
||||
|
||||
|
||||
def test_serialize_error():
|
||||
|
||||
@@ -57,7 +57,8 @@ class Client(object):
|
||||
ecdsa_curve_name=self.curve)
|
||||
|
||||
pubkey = node.node.public_key
|
||||
return formats.export_public_key(pubkey=pubkey, label=label)
|
||||
vk = formats.decompress_pubkey(pubkey=pubkey, curve_name=self.curve)
|
||||
return formats.export_public_key(vk=vk, label=label)
|
||||
|
||||
def sign_ssh_challenge(self, label, blob):
|
||||
identity = self.get_identity(label=label)
|
||||
@@ -72,7 +73,8 @@ class Client(object):
|
||||
challenge_visual=visual,
|
||||
ecdsa_curve_name=self.curve)
|
||||
|
||||
verifying_key = formats.decompress_pubkey(result.public_key)
|
||||
verifying_key = formats.decompress_pubkey(pubkey=result.public_key,
|
||||
curve_name=self.curve)
|
||||
key_type, blob = formats.serialize_verifying_key(verifying_key)
|
||||
assert blob == msg['public_key']['blob']
|
||||
assert key_type == msg['key_type']
|
||||
|
||||
Reference in New Issue
Block a user