From e19d76398eab94d0a3eb10ea1bb5a813379f9cf3 Mon Sep 17 00:00:00 2001 From: Roman Zeyde Date: Fri, 18 Dec 2015 16:03:50 +0200 Subject: [PATCH] formats: verify public key according to requested ECDSA curve --- trezor_agent/formats.py | 37 +++++++++++++++++++++--------- trezor_agent/tests/test_formats.py | 27 ++++++++++++++++++---- trezor_agent/trezor/client.py | 6 +++-- 3 files changed, 53 insertions(+), 17 deletions(-) diff --git a/trezor_agent/formats.py b/trezor_agent/formats.py index d3f3153..c6fa2c0 100644 --- a/trezor_agent/formats.py +++ b/trezor_agent/formats.py @@ -79,26 +79,43 @@ def parse_pubkey(blob): return result -def decompress_pubkey(pub): - if pub[:1] == b'\x00': +def _decompress_ed25519(pubkey): + if pubkey[:1] == b'\x00': # set by Trezor fsm_msgSignIdentity() and fsm_msgGetPublicKey() - return ed25519.VerifyingKey(pub[1:]) + return ed25519.VerifyingKey(pubkey[1:]) - if pub[:1] in {b'\x02', b'\x03'}: # set by ecdsa_get_public_key33() + +def _decompress_nist256(pubkey): + if pubkey[:1] in {b'\x02', b'\x03'}: # set by ecdsa_get_public_key33() curve = ecdsa.NIST256p P = curve.curve.p() A = curve.curve.a() B = curve.curve.b() - x = util.bytes2num(pub[1:33]) + x = util.bytes2num(pubkey[1:33]) beta = pow(int(x * x * x + A * x + B), int((P + 1) // 4), int(P)) - p0 = util.bytes2num(pub[:1]) + p0 = util.bytes2num(pubkey[:1]) y = (P - beta) if ((beta + p0) % 2) else beta point = ecdsa.ellipticcurve.Point(curve.curve, x, y) return ecdsa.VerifyingKey.from_public_point(point, curve=curve, hashfunc=hashfunc) - raise ValueError('invalid {!r}', pub) + + +def decompress_pubkey(pubkey, curve_name): + vk = None + if len(pubkey) == 33: + decompress = { + CURVE_NIST256: _decompress_nist256, + CURVE_ED25519: _decompress_ed25519 + }[curve_name] + vk = decompress(pubkey) + + if not vk: + msg = 'invalid {!s} public key: {!r}'.format(curve_name, pubkey) + raise ValueError(msg) + + return vk def serialize_verifying_key(vk): @@ -119,10 +136,8 @@ def serialize_verifying_key(vk): raise TypeError('unsupported {!r}'.format(vk)) -def export_public_key(pubkey, label): - assert len(pubkey) == 33 - key_type, blob = serialize_verifying_key(decompress_pubkey(pubkey)) - +def export_public_key(vk, label): + key_type, blob = serialize_verifying_key(vk) log.debug('fingerprint: %s', fingerprint(blob)) b64 = base64.b64encode(blob).decode('ascii') return '{} {} {}\n'.format(key_type.decode('ascii'), b64, label) diff --git a/trezor_agent/tests/test_formats.py b/trezor_agent/tests/test_formats.py index 232410a..754c262 100644 --- a/trezor_agent/tests/test_formats.py +++ b/trezor_agent/tests/test_formats.py @@ -35,8 +35,9 @@ def test_parse_public_key(): def test_decompress(): blob = '036236ceabde25207e81e404586e3a3af1acda1dfed2abbbb4876c1fc5b296b575' - result = formats.export_public_key(binascii.unhexlify(blob), label='home') - assert result == _public_key + vk = formats.decompress_pubkey(binascii.unhexlify(blob), + curve_name=formats.CURVE_NIST256) + assert formats.export_public_key(vk, label='home') == _public_key def test_parse_ed25519(): @@ -57,7 +58,7 @@ def test_parse_ed25519(): def test_export_ed25519(): pub = (b'\x00P]\x17kc}#\xbc\x9c\xb2"\xef~\xa2\xb3\xe7\xf4' b'z\xba\xb6\xf1\x14\xdc\xec)\x0c\xd7SY\xb52\x91') - vk = formats.decompress_pubkey(pub) + vk = formats.decompress_pubkey(pub, formats.CURVE_ED25519) result = formats.serialize_verifying_key(vk) assert result == (b'ssh-ed25519', b'\x00\x00\x00\x0bssh-ed25519\x00\x00\x00 P]\x17kc}#\xbc' @@ -67,7 +68,25 @@ def test_export_ed25519(): def test_decompress_error(): with pytest.raises(ValueError): - formats.decompress_pubkey('') + formats.decompress_pubkey('', formats.CURVE_NIST256) + + +def test_curve_mismatch(): + # NIST256 public key + blob = '036236ceabde25207e81e404586e3a3af1acda1dfed2abbbb4876c1fc5b296b575' + with pytest.raises(ValueError): + formats.decompress_pubkey(binascii.unhexlify(blob), + curve_name=formats.CURVE_ED25519) + + blob = '00' * 33 # Dummy public key + with pytest.raises(ValueError): + formats.decompress_pubkey(binascii.unhexlify(blob), + curve_name=formats.CURVE_NIST256) + + blob = 'FF' * 33 # Unsupported prefix byte + with pytest.raises(ValueError): + formats.decompress_pubkey(binascii.unhexlify(blob), + curve_name=formats.CURVE_NIST256) def test_serialize_error(): diff --git a/trezor_agent/trezor/client.py b/trezor_agent/trezor/client.py index 893d19b..20d1d7d 100644 --- a/trezor_agent/trezor/client.py +++ b/trezor_agent/trezor/client.py @@ -57,7 +57,8 @@ class Client(object): ecdsa_curve_name=self.curve) pubkey = node.node.public_key - return formats.export_public_key(pubkey=pubkey, label=label) + vk = formats.decompress_pubkey(pubkey=pubkey, curve_name=self.curve) + return formats.export_public_key(vk=vk, label=label) def sign_ssh_challenge(self, label, blob): identity = self.get_identity(label=label) @@ -72,7 +73,8 @@ class Client(object): challenge_visual=visual, ecdsa_curve_name=self.curve) - verifying_key = formats.decompress_pubkey(result.public_key) + verifying_key = formats.decompress_pubkey(pubkey=result.public_key, + curve_name=self.curve) key_type, blob = formats.serialize_verifying_key(verifying_key) assert blob == msg['public_key']['blob'] assert key_type == msg['key_type']