{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}

module Crypto.HPKE.KDF (
    KDF (..),
    HashAlgorithm,
    SHA256 (..),
    SHA384 (..),
    SHA512 (..),
    PRK,
    extractAndExpand,
)
where

import Crypto.Hash.Algorithms (
    HashAlgorithm,
    SHA256 (..),
    SHA384 (..),
    SHA512 (..),
 )
import Crypto.KDF.HKDF (PRK)
import qualified Crypto.KDF.HKDF as HKDF

import Crypto.HPKE.Types

----------------------------------------------------------------

-- | HPKE key derivation, indexed by the hash algorithm @h@ whose HKDF
-- pseudorandom keys ('PRK') it produces and consumes.
class KDF h where
    -- | Counterpart of RFC 9180 @LabeledExtract@: takes the suite
    -- identifier, a salt, a label and input keying material, and yields a
    -- pseudorandom key for hash @h@.
    labeledExtract
        :: Suite
        -> Salt
        -> Label
        -> IKM
        -> PRK h

    -- | Counterpart of RFC 9180 @LabeledExpand@: takes the suite
    -- identifier, a pseudorandom key, a label, context info and an output
    -- length, and yields the derived key.
    labeledExpand
        :: Suite
        -> PRK h
        -> Label
        -> Info
        -> Int
        -> Key

-- | HKDF instantiated with SHA-256. Both methods delegate to the
-- hash-generic helpers 'labeledExtract_' and 'labeledExpand_'.
instance KDF SHA256 where
    labeledExtract = labeledExtract_
    labeledExpand = labeledExpand_

-- | HKDF instantiated with SHA-384. Both methods delegate to the
-- hash-generic helpers 'labeledExtract_' and 'labeledExpand_'.
instance KDF SHA384 where
    labeledExtract = labeledExtract_
    labeledExpand = labeledExpand_

-- | HKDF instantiated with SHA-512. Both methods delegate to the
-- hash-generic helpers 'labeledExtract_' and 'labeledExpand_'.
instance KDF SHA512 where
    labeledExtract = labeledExtract_
    labeledExpand = labeledExpand_

----------------------------------------------------------------

-- | Hash-generic implementation of RFC 9180 @LabeledExtract@.
--
-- Runs HKDF-Extract with the given @salt@ over the input keying material,
-- prefixed with the domain-separation tag, the suite identifier and the
-- label:
--
-- > labeled_ikm = "HPKE-v1" <> suite <> label <> ikm
-- > result      = HKDF-Extract(salt, labeled_ikm)
labeledExtract_
    :: HashAlgorithm a => Suite -> Salt -> Label -> IKM -> PRK a
labeledExtract_ suite salt label ikm = HKDF.extract salt labeled_ikm
  where
    -- The fixed "HPKE-v1" prefix plus the suite id keeps keys derived for
    -- different protocols and cipher suites distinct.
    labeled_ikm :: IKM
    labeled_ikm = "HPKE-v1" <> suite <> label <> ikm

-- | Hash-generic implementation of RFC 9180 @LabeledExpand@.
--
-- Runs HKDF-Expand on @prk@ to produce @len@ bytes, binding the output to
-- the requested length, the suite identifier, the label and @info@:
--
-- > labeled_info = I2OSP(len, 2) <> "HPKE-v1" <> suite <> label <> info
-- > result       = HKDF-Expand(prk, labeled_info, len)
--
-- NOTE(review): 'i2ospOf_' is the partial variant of @i2ospOf@. A @len@
-- that does not fit in two bytes (> 65535) presumably aborts via 'error'.
-- RFC 5869 also caps HKDF output at 255 * hash-length bytes. Confirm that
-- callers never exceed either limit.
labeledExpand_
    :: HashAlgorithm a => Suite -> PRK a -> Label -> Info -> Int -> Key
labeledExpand_ suite prk label info len = HKDF.expand prk labeled_info len
  where
    -- The output length is encoded as a 2-byte prefix, so requests for
    -- different lengths yield unrelated keys.
    labeled_info :: Info
    labeled_info =
        i2ospOf_ 2 (fromIntegral len) <> "HPKE-v1" <> suite <> label <> info

----------------------------------------------------------------

-- | RFC 9180 @ExtractAndExpand@ as used by DH-based KEMs.
--
-- First extracts @eae_prk@ from the raw Diffie-Hellman output, using an
-- empty salt and the label @"eae_prk"@. Then expands it with the label
-- @"shared_secret"@ and the KEM context into a secret whose length is the
-- digest size of @h@.
--
-- The value of @h@ is passed only to 'hashDigestSize'. Its type also
-- selects which 'KDF' instance performs the extract and expand steps.
extractAndExpand
    :: forall h
     . (HashAlgorithm h, KDF h)
    => h -> Suite -> KeyDeriveFunction
extractAndExpand h suite dh kem_context = shared_secret
  where
    -- The signature pins the PRK to hash @h@, so both KDF steps use the
    -- same instance (requires ScopedTypeVariables and the forall above).
    eae_prk :: PRK h
    eae_prk = labeledExtract suite "" "eae_prk" $ convert dh

    -- Output length of the shared secret: one digest of @h@.
    siz :: Int
    siz = hashDigestSize h

    shared_secret :: Key
    shared_secret =
        labeledExpand suite eae_prk "shared_secret" kem_context siz