{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}

module Crypto.HPKE.KEM (
    encapGen,
    encapEnv,
    decapEnv,
    genKeyPairP,
)
where

import qualified Control.Exception as E
import Crypto.ECC (
    EllipticCurve (..),
    EllipticCurveDH (..),
    KeyPair (..),
 )
import Crypto.Random (drgNew, withDRG)

import Crypto.HPKE.PublicKey
import Crypto.HPKE.Types

----------------------------------------------------------------

-- | Build the encapsulation function for an environment.
--
-- The environment's secret key is used as the ephemeral key @skE@.  Given
-- the recipient's encoded public key @pkRm@, this computes @DH(skE, pkR)@.
-- When 'envAuthKey' holds a sender key @skS@, it also computes
-- @DH(skS, pkR)@ and appends it.  The KEM context is
-- @pkEm <> pkRm <> pkSm@, where @pkSm@ is empty when there is no sender key.
-- 'envDerive' turns the DH output and that context into the shared secret.
--
-- Returns the shared secret together with the encoded ephemeral public key.
-- Any DH failure is reported as @EncapError "encap"@.
encap
    :: (EllipticCurve group, EllipticCurveDH group)
    => Env group
    -> Encap
encap Env{..} enc0@(EncodedPublicKey pkRm) = do
    pkR <- deserializePublicKey envProxy enc0
    let skE = envSecretKey
        failure = EncapError "encap"
    dhE <- ecdh' envProxy skE pkR failure
    (dh, pkSm) <- case envAuthKey of
        Nothing -> Right (dhE, "")
        Just skS -> do
            dhS <- ecdh' envProxy skS pkR failure
            -- The sender's public key takes part in the KEM context.
            let EncodedPublicKey senderPub =
                    serializePublicKey envProxy $ scalarToPoint envProxy skS
            Right (dhE <> dhS, senderPub)
    let enc@(EncodedPublicKey pkEm) =
            serializePublicKey envProxy $ scalarToPoint envProxy skE
        kemContext = pkEm <> pkRm <> pkSm
        sharedSecret = SharedSecret $ convert $ envDerive dh kemContext
    Right (sharedSecret, enc)

-- | Build an 'Encap' backed by a freshly generated ephemeral key pair.
--
-- The optional encoded sender secret key enables authenticated mode.  It is
-- decoded before the key pair is generated.  If decoding fails, the
-- 'HPKEError' from 'deserializeSecretKey' is thrown in 'IO' with
-- 'E.throwIO'.
encapGen
    :: (EllipticCurve group, EllipticCurveDH group)
    => Proxy group
    -> KeyDeriveFunction
    -> Maybe EncodedSecretKey
    -> IO Encap
encapGen proxy derive mskSm = do
    mskS <- traverse decodeOrThrow mskSm
    encap <$> genEnv proxy derive mskS
  where
    -- Turn a decoding failure into an IO exception.
    decodeOrThrow encoded =
        either E.throwIO return $ deserializeSecretKey proxy encoded

-- | Build an 'Encap' from explicitly supplied encoded keys.
--
-- The first 'EncodedSecretKey' becomes the environment's secret key.
-- 'encap' uses that key as the ephemeral key @skE@, so the result is
-- deterministic.  The optional second key is the sender key for
-- authenticated mode.  Decoding errors are returned as 'Left'.
encapEnv
    :: (EllipticCurve group, EllipticCurveDH group)
    => Proxy group
    -> KeyDeriveFunction
    -> EncodedSecretKey
    -> Maybe EncodedSecretKey
    -> Encap
encapEnv proxy derive skEm mskSm enc =
    newEnvDeserialize proxy derive skEm mskSm >>= \env -> encap env enc

----------------------------------------------------------------

-- | Build the decapsulation function for an environment.
--
-- The environment's secret key is the recipient key @skR@.  Given the
-- encoded ephemeral public key @enc@, this computes @DH(skR, pkE)@.  When
-- 'envAuthKey' holds a sender key @skS@, it also computes @DH(skR, pkS)@,
-- with @pkS@ derived from @skS@, and appends it.  That value should match
-- the @DH(skS, pkR)@ computed by 'encap'.  The KEM context is
-- @pkEm <> pkRm <> pkSm@, mirroring 'encap', where @pkSm@ is empty when
-- there is no sender key.  'envDerive' turns the DH output and that context
-- into the shared secret.
--
-- Every DH failure is reported as @DecapError "decap"@.
--
-- NOTE(review): only the sender's public key is mathematically needed for
-- the authenticated DH.  The secret key is required here only because 'Env'
-- stores @Maybe (SecretKey group)@.
decap
    :: (EllipticCurve group, EllipticCurveDH group)
    => Env group
    -> Decap
decap Env{..} enc@(EncodedPublicKey pkEm) = do
    pkE <- deserializePublicKey envProxy enc
    let skR = envSecretKey
    dh0 <- ecdh' envProxy skR pkE $ DecapError "decap"
    (dh, pkSm) <- case envAuthKey of
        Nothing -> return (dh0, "")
        Just skS -> do
            let pkS = scalarToPoint envProxy skS
            -- A failure here is still a decapsulation failure, so report it
            -- with the same constructor as the first DH above.
            dh1 <- ecdh' envProxy skR pkS $ DecapError "decap"
            let EncodedPublicKey pk = serializePublicKey envProxy pkS
            return (dh0 <> dh1, pk)
    -- The recipient's own public key is part of the KEM context.
    let pkR = scalarToPoint envProxy skR
        EncodedPublicKey pkRm = serializePublicKey envProxy pkR
        kem_context = pkEm <> pkRm <> pkSm
        shared_secret = SharedSecret $ convert $ envDerive dh kem_context
    return shared_secret

-- | Build a 'Decap' from explicitly supplied encoded keys.
--
-- The first 'EncodedSecretKey' is the recipient key @skR@.  The optional
-- second key is the sender secret key used in authenticated mode.  Decoding
-- errors are returned as 'Left'.
decapEnv
    :: (EllipticCurve group, EllipticCurveDH group)
    => Proxy group
    -> KeyDeriveFunction
    -> EncodedSecretKey
    -> Maybe EncodedSecretKey
    -> Decap
decapEnv proxy derive skRm mskSm enc =
    newEnvDeserialize proxy derive skRm mskSm >>= \env -> decap env enc

----------------------------------------------------------------

{- FOURMOLU_DISABLE -}
-- | Key material and configuration shared by 'encap' and 'decap'.
data Env group = Env
    { envSecretKey :: SecretKey group
    -- ^ Own secret key: the ephemeral key @skE@ in 'encap', the recipient
    -- key @skR@ in 'decap'.
    , envAuthKey   :: Maybe (SecretKey group)
    -- ^ Sender secret key @skS@ for authenticated mode, if any.
    , envProxy     :: Proxy group
    -- ^ Selects the elliptic curve group.
    , envDerive    :: KeyDeriveFunction
    -- ^ Derives the shared secret from the DH output and the KEM context.
    }
{- FOURMOLU_ENABLE -}

----------------------------------------------------------------

-- | Assemble an 'Env' from already-decoded keys.
--
-- The first key becomes 'envSecretKey' and the optional second key becomes
-- 'envAuthKey'.  The group proxy is fixed by the result type.
newEnv
    :: forall group
     . EllipticCurve group
    => KeyDeriveFunction
    -> SecretKey group
    -> Maybe (SecretKey group)
    -> Env group
newEnv derive sk mskS =
    Env
        { envProxy = Proxy
        , envDerive = derive
        , envSecretKey = sk
        , envAuthKey = mskS
        }

----------------------------------------------------------------

-- | Create an 'Env' whose secret key comes from a freshly generated key
-- pair.  The generated public point is discarded.
genEnv
    :: EllipticCurve group
    => Proxy group
    -> KeyDeriveFunction
    -> Maybe (SecretKey group)
    -> IO (Env group)
genEnv proxy derive mskS = fromPair <$> genKeyPairP proxy
  where
    -- Keep only the secret half of the generated pair.
    fromPair (_, freshSk) = newEnv derive freshSk mskS

-- | Generate a key pair for the curve selected by the proxy.
--
-- A new generator is obtained with 'drgNew' and passed to
-- 'curveGenerateKeyPair' via 'withDRG'.  The returned pair is
-- @(publicPoint, secretScalar)@.
genKeyPairP
    :: EllipticCurve curve
    => proxy curve -> IO (Point curve, Scalar curve)
genKeyPairP proxy = fmap generateWith drgNew
  where
    -- Run the pure generator; the updated DRG state is not needed.
    generateWith drg =
        let (KeyPair point scalar, _drg') =
                withDRG drg (curveGenerateKeyPair proxy)
         in (point, scalar)

----------------------------------------------------------------

-- | Decode the secret keys and assemble an 'Env'.
--
-- The mandatory key is decoded first.  If both keys are malformed, the
-- error for the mandatory key is the one returned.
newEnvDeserialize
    :: EllipticCurve group
    => Proxy group
    -> KeyDeriveFunction
    -> EncodedSecretKey
    -> Maybe EncodedSecretKey
    -> Either HPKEError (Env group)
newEnvDeserialize proxy derive skRm mskSm =
    newEnv derive
        <$> deserializeSecretKey proxy skRm
        <*> traverse (deserializeSecretKey proxy) mskSm

----------------------------------------------------------------

-- | Run 'ecdh' and convert its 'CryptoFailable' result to 'Either'.
--
-- On failure the underlying 'CryptoError' is discarded and the
-- caller-supplied error value is returned in 'Left'.
ecdh'
    :: EllipticCurveDH group
    => Proxy group
    -> SecretKey group
    -> PublicKey group
    -> a
    -> Either a SharedSecret
ecdh' proxy sk pk err = toEither (ecdh proxy sk pk)
  where
    toEither (CryptoPassed shared) = Right shared
    toEither (CryptoFailed _) = Left err