{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Crypto.Token (
Config,
defaultConfig,
interval,
tokenLifetime,
threadName,
TokenManager,
spawnTokenManager,
killTokenManager,
encryptToken,
decryptToken,
) where
import Control.Concurrent
import Crypto.Cipher.AES (AES256)
import Crypto.Cipher.Types (AEADMode (..), AuthTag (..))
import qualified Crypto.Cipher.Types as C
import Crypto.Error (maybeCryptoError, throwCryptoError)
import Crypto.Random (getRandomBytes)
import Data.Array.IO
import Data.Bits (xor)
import qualified Data.ByteArray as BA
import qualified Data.ByteString as BS
import qualified Data.ByteString.Internal as BS
import qualified Data.IORef as I
import Data.Word
import Foreign.Ptr
import Foreign.Storable
import GHC.Conc.Sync (labelThread)
import Network.ByteOrder
-- | Position of a secret inside the rotating secret array. It travels in the
--   token header as a 16-bit field.
type Index = Word16
-- | Per-secret encryption counter. It travels in the token header as a
--   64-bit field and is combined with the secret's IV to form the nonce.
type Counter = Word64
-- | Configuration for a token manager.
data Config = Config
    { interval :: Int
    -- ^ Time between two secret rotations, in seconds (the rotation thread
    --   sleeps @interval * 1000000@ microseconds between updates).
    , tokenLifetime :: Int
    -- ^ How long a token stays decryptable, in the same unit as 'interval'.
    --   The number of secrets kept is @tokenLifetime \`div\` interval@.
    , threadName :: String
    -- ^ Label attached to the rotation thread.
    }
    deriving (Eq, Show)
-- | Default configuration: rotate the secret every 30 minutes
--   (1,800 seconds) and keep tokens valid for 2 hours (7,200 seconds).
defaultConfig :: Config
defaultConfig =
    Config
        { interval = 30 * 60
        , tokenLifetime = 2 * 60 * 60
        , threadName = "Crypto token manager"
        }
-- | Handle to a running token manager.
data TokenManager = TokenManager
    { headerMask :: Header
    -- ^ Random mask XORed over every token header before it is serialised.
    , getEncryptSecret :: IO (Secret, Index)
    -- ^ Fetch the secret currently used for encryption and its slot index.
    , getDecryptSecret :: Index -> IO Secret
    -- ^ Fetch the secret stored in the given slot.
    , threadId :: ThreadId
    -- ^ The thread that periodically rotates the secrets.
    }
-- | Spawn a token manager.
--
-- An array of secrets is allocated, a fresh secret is stored in slot 0, and
-- a background thread (labelled with 'threadName') replaces the next slot
-- with a new secret once per 'interval'.
--
-- The number of slots is @tokenLifetime \`div\` interval@, clamped to the
-- range 1..65535. Without the clamp a lifetime shorter than the interval
-- yields zero slots; the upper array bound then wraps around in 'Index'
-- arithmetic and every later @\`mod\` (n + 1)@ divides by zero. A
-- non-positive 'interval' is treated as one second for the same reason.
spawnTokenManager :: Config -> IO TokenManager
spawnTokenManager Config{..} = do
    emp <- emptySecret
    arr <- newArray (0, slots - 1) emp
    ent <- generateSecret
    writeArray arr 0 ent
    ref <- I.newIORef 0
    tid <- forkIO $ loop arr ref
    labelThread tid threadName
    msk <- newHeaderMask
    return $ TokenManager msk (readCurrentSecret arr ref) (readSecret arr) tid
  where
    -- Rotation period in seconds, never below one second.
    step :: Int
    step = max 1 interval

    -- Largest slot count whose upper bound plus one still fits in 'Index'.
    maxSlots :: Int
    maxSlots = fromIntegral (maxBound :: Index)

    -- Number of secrets kept alive at the same time.
    slots :: Index
    slots = fromIntegral (max 1 (min maxSlots (tokenLifetime `div` step)))

    -- Sleep for one period, rotate, repeat forever.
    loop :: IOArray Index Secret -> I.IORef Index -> IO a
    loop arr ref = do
        threadDelay (step * 1000000)
        update arr ref
        loop arr ref
-- | Rotate the secrets: generate a new secret, store it in the slot after
--   the current one (wrapping to slot 0 past the upper bound) and make that
--   slot the current one.
--
-- The new secret is written to the array before the reference is updated,
-- so a reader that sees the new index also finds the new secret.
update :: IOArray Index Secret -> I.IORef Index -> IO ()
update arr ref = do
    idx0 <- I.readIORef ref
    (_, n) <- getBounds arr
    -- Wrap by comparison instead of @(idx0 + 1) \`mod\` (n + 1)@: when the
    -- upper bound is 'maxBound', @n + 1@ overflows to 0 and 'mod' would
    -- raise a divide-by-zero exception, killing the rotation thread.
    let idx
            | idx0 >= n = 0
            | otherwise = idx0 + 1
    sec <- generateSecret
    writeArray arr idx sec
    I.writeIORef ref idx
-- | Stop a token manager by killing its secret-rotation thread.
killTokenManager :: TokenManager -> IO ()
killTokenManager mgr = killThread (threadId mgr)
-- | Read the secret stored in the slot selected by the given index.
--
-- The index is reduced modulo the number of slots, so any 16-bit value
-- (for example one taken from a tampered token header) maps to a valid slot.
readSecret :: IOArray Index Secret -> Index -> IO Secret
readSecret secrets idx0 = do
    (_, n) <- getBounds secrets
    -- When the upper bound is 'maxBound', @n + 1@ overflows to 0 and 'mod'
    -- would raise a divide-by-zero exception; every index is already in
    -- range in that case.
    let idx
            | n == maxBound = idx0
            | otherwise = idx0 `mod` (n + 1)
    readArray secrets idx
-- | Read the index of the current slot and return the secret stored there
--   together with that index.
readCurrentSecret :: IOArray Index Secret -> I.IORef Index -> IO (Secret, Index)
readCurrentSecret arr ref = do
    current <- I.readIORef ref
    flip (,) current <$> readSecret arr current
-- | Key material stored in one slot of the secret array.
data Secret = Secret
    { secretIV :: ByteString
    -- ^ Base IV; XORed with the counter to build each nonce.
    , secretKey :: ByteString
    -- ^ AES-256 key.
    , secretCounter :: I.IORef Counter
    -- ^ Number of encryptions performed with this secret so far.
    }
-- | A placeholder secret with an empty IV, an empty key and a counter of 0.
--   It is the initial content of every slot of the secret array.
emptySecret :: IO Secret
emptySecret = do
    counter <- I.newIORef 0
    return $ Secret BS.empty BS.empty counter
-- | Create a fresh secret: a random IV, a random key and a counter
--   starting at 0.
generateSecret :: IO Secret
generateSecret = do
    iv <- genIV
    key <- genKey
    counter <- I.newIORef 0
    return $ Secret iv key counter
-- | Draw 'keyLength' random bytes to be used as a key.
genKey :: IO ByteString
genKey = getRandomBytes keyLength
-- | Draw 'ivLength' random bytes to be used as a base IV.
genIV :: IO ByteString
genIV = getRandomBytes ivLength
-- | Length of the base IV (and of the nonce derived from it), in bytes.
ivLength :: Int
ivLength = 8
-- | Length of the encryption key, in bytes.
keyLength :: Int
keyLength = 32
-- | Size of the slot index inside the token header, in bytes.
indexLength :: Int
indexLength = 2
-- | Size of the counter inside the token header, in bytes.
counterLength :: Int
counterLength = 8
-- | Length of the authentication tag appended to the ciphertext, in bytes.
tagLength :: Int
tagLength = 16
-- | Token header: identifies the secret slot and the counter value that
--   were used to encrypt the token body.
data Header = Header
    { headerIndex :: Index
    -- ^ Slot of the secret used for encryption.
    , headerCounter :: Counter
    -- ^ Counter value used to derive the nonce.
    }
-- | Serialise a header into @indexLength + counterLength@ (10) bytes: the
--   index written with 'write16', followed by the counter written with
--   'write64'.
encodeHeader :: Header -> IO ByteString
encodeHeader Header{..} =
    withWriteBuffer (indexLength + counterLength) $ \wbuf -> do
        write16 wbuf headerIndex
        write64 wbuf headerCounter
-- | Parse a header: a 16-bit index read with 'read16' followed by a 64-bit
--   counter read with 'read64'.
--
-- NOTE(review): the input is presumably expected to hold at least
-- @indexLength + counterLength@ bytes; no length check is done here, so
-- callers have to guarantee it.
decodeHeader :: ByteString -> IO Header
decodeHeader bs = withReadBuffer bs $ \rbuf ->
    Header <$> read16 rbuf <*> read64 rbuf
-- | Create a random header mask by decoding @indexLength + counterLength@
--   random bytes as a header.
newHeaderMask :: IO Header
newHeaderMask = do
    bin <- getRandomBytes (indexLength + counterLength) :: IO ByteString
    decodeHeader bin
-- | Field-wise XOR of two headers. Applying the same mask twice gives back
--   the original header, so this both masks and unmasks.
xorHeader :: Header -> Header -> Header
xorHeader x y =
    Header
        { headerIndex = headerIndex x `xor` headerIndex y
        , headerCounter = headerCounter x `xor` headerCounter y
        }
-- | Build a token: mask the header made of the slot index and the counter
--   with the manager's 'headerMask', serialise it, and prepend it to the
--   ciphertext.
addHeader :: TokenManager -> Index -> Counter -> ByteString -> IO ByteString
addHeader TokenManager{..} idx counter cipher = do
    let hdr = Header idx counter
        mskhdr = headerMask `xorHeader` hdr
    hdrbin <- encodeHeader mskhdr
    return (hdrbin `BS.append` cipher)
-- | Split a token into its unmasked header fields and the ciphertext.
--
-- Returns 'Nothing' when the token is shorter than a header; otherwise
-- returns the slot index, the counter and the remaining bytes.
delHeader
    :: TokenManager -> ByteString -> IO (Maybe (Index, Counter, ByteString))
delHeader TokenManager{..} token
    | BS.length token < minlen = return Nothing
    | otherwise = do
        let (hdrbin, cipher) = BS.splitAt minlen token
        mskhdr <- decodeHeader hdrbin
        -- XOR with the same mask undoes the masking applied on encryption.
        let hdr = headerMask `xorHeader` mskhdr
        return $ Just (headerIndex hdr, headerCounter hdr, cipher)
  where
    -- Size of the serialised header in bytes.
    minlen :: Int
    minlen = indexLength + counterLength
-- | Encrypt a value into a token.
--
-- The value is encrypted with the manager's current secret and the masked
-- header (slot index and counter) is prepended to the result.
encryptToken
    :: TokenManager
    -> ByteString
    -> IO ByteString
encryptToken mgr x = do
    (current, slot) <- getEncryptSecret mgr
    (seqno, sealed) <- encrypt current x
    addHeader mgr slot seqno sealed
-- | Encrypt a plaintext with the given secret.
--
-- The secret's counter is atomically incremented; the value it held before
-- the increment is used to derive the nonce and is returned together with
-- the ciphertext.
encrypt
    :: Secret
    -> ByteString
    -> IO (Counter, ByteString)
encrypt Secret{..} plain = do
    counter <- I.atomicModifyIORef' secretCounter postIncrement
    nonce <- makeNonce counter secretIV
    return (counter, aes256gcmEncrypt plain secretKey nonce)
  where
    -- Store the successor, hand back the previous value.
    postIncrement i = (i + 1, i)
-- | Decrypt a token back into the original value.
--
-- Returns 'Nothing' when the token is too short to contain a header or when
-- decryption with the secret selected by the header fails.
decryptToken
    :: TokenManager
    -> ByteString
    -> IO (Maybe ByteString)
decryptToken mgr token = delHeader mgr token >>= maybe (return Nothing) open
  where
    -- Look up the secret named by the header and decrypt the body with it.
    open (idx, counter, cipher) = do
        secret <- getDecryptSecret mgr idx
        decrypt secret counter cipher
-- | Decrypt a ciphertext with the given secret, rebuilding the nonce from
--   the counter carried in the token. Returns 'Nothing' on failure.
decrypt
    :: Secret
    -> Counter
    -> ByteString
    -> IO (Maybe ByteString)
decrypt secret counter cipher =
    aes256gcmDecrypt cipher (secretKey secret)
        <$> makeNonce counter (secretIV secret)
-- | Derive a nonce by XORing the secret's IV with the counter.
--
-- The counter is written into an 'ivLength'-byte buffer with 'poke'.
-- NOTE(review): 'poke' presumably stores the 'Word64' in the host's native
-- byte order, so nonces would differ between hosts of different endianness
-- — confirm tokens are only ever decrypted by the process that made them.
makeNonce :: Counter -> ByteString -> IO ByteString
makeNonce counter iv = BA.xor iv <$> counterBytes
  where
    counterBytes :: IO ByteString
    counterBytes = BS.create ivLength $ \ptr -> poke (castPtr ptr) counter
-- | Additional authenticated data passed to AES-GCM; always empty.
constantAdditionalData :: ByteString
constantAdditionalData = mempty
-- | AES-256-GCM encryption.
--
-- Arguments are the plaintext, the key and the nonce. The result is the
-- ciphertext followed by the 'tagLength'-byte authentication tag.
-- Cipher or AEAD initialisation failures are raised through
-- 'throwCryptoError'.
aes256gcmEncrypt
    :: ByteString
    -> ByteString
    -> ByteString
    -> ByteString
aes256gcmEncrypt plain key nonce = sealed `BS.append` BA.convert tag
  where
    aes :: AES256
    aes = throwCryptoError (C.cipherInit key)
    gcm = throwCryptoError (C.aeadInit AEAD_GCM aes nonce)
    (AuthTag tag, sealed) =
        C.aeadSimpleEncrypt gcm constantAdditionalData plain tagLength
-- | AES-256-GCM decryption.
--
-- Arguments are the ciphertext with the authentication tag appended, the
-- key and the nonce. Returns 'Nothing' when the input is shorter than the
-- tag, when the cipher or AEAD state cannot be initialised, or when the
-- tag does not authenticate the ciphertext.
aes256gcmDecrypt
    :: ByteString
    -> ByteString
    -> ByteString
    -> Maybe ByteString
aes256gcmDecrypt ctexttag key nonce
    -- Reject truncated input explicitly instead of relying on a negative
    -- split position and a tag-length mismatch further down.
    | BS.length ctexttag < tagLength = Nothing
    | otherwise = do
        aes <- maybeCryptoError $ C.cipherInit key :: Maybe AES256
        aead <- maybeCryptoError $ C.aeadInit AEAD_GCM aes nonce
        -- The last 'tagLength' bytes are the tag; the rest is ciphertext.
        let (ctext, tag) = BS.splitAt (BS.length ctexttag - tagLength) ctexttag
            authtag = AuthTag $ BA.convert tag
        C.aeadSimpleDecrypt aead constantAdditionalData ctext authtag