{-# LANGUAGE GADTs #-}

-- | This module defines the multi-thread backend for the `Network` monad.
module Choreography.Network.Local where

import Choreography.Location
import Choreography.Network
import Control.Concurrent
import Control.Concurrent.Chan
import Control.Monad
import Control.Monad.Freer
import Control.Monad.IO.Class
import Data.HashMap.Strict (HashMap, (!))
import Data.HashMap.Strict qualified as HashMap

-- | The message buffer owned by a single (receiving) location. It holds one
-- channel per sending location, keyed by the sender's 'LocTm'; each channel
-- carries the messages sent by that location, serialized as 'String's.
type MsgBuf = HashMap LocTm (Chan String)

-- | Configuration of the local (multi-thread) backend: the message buffers of
-- all participating locations.
newtype LocalConfig = LocalConfig
  { -- | Maps every location to the 'MsgBuf' it receives messages in.
    locToBuf :: HashMap LocTm MsgBuf
  }

-- | Create a message buffer containing one fresh, empty channel for every
-- location in the given list. If a location occurs more than once, the
-- channel created last is the one kept.
newEmptyMsgBuf :: [LocTm] -> IO MsgBuf
newEmptyMsgBuf = foldM f HashMap.empty
  where
    -- Allocate a new channel for @loc@ and add it to the buffer built so far.
    f hash loc = do
      chan <- newChan
      return (HashMap.insert loc chan hash)

-- | Build a 'LocalConfig' for the given locations: every location gets its
-- own message buffer, and every buffer holds one channel per location in the
-- list (so a location also has a channel for messages from itself).
mkLocalConfig :: [LocTm] -> IO LocalConfig
mkLocalConfig ls = LocalConfig <$> foldM f HashMap.empty ls
  where
    -- Create the buffer owned by @loc@ and register it in the map built so
    -- far. The argument is named @ls@ so it does not shadow the top-level
    -- 'locs' function.
    f hash loc = do
      buf <- newEmptyMsgBuf ls
      return (HashMap.insert loc buf hash)

-- | All locations registered in a 'LocalConfig', i.e. the keys of 'locToBuf'.
-- The order of the result is whatever 'HashMap.keys' yields.
locs :: LocalConfig -> [LocTm]
locs = HashMap.keys . locToBuf

-- | Interpret a 'Network' program as the location @self@, using the channels
-- stored in @cfg@ for communication.
--
-- Locations are looked up with the partial operator '!', so sending to,
-- receiving from, or running as a location that is not present in @cfg@
-- raises an error. Messages travel as 'String's produced by 'show' and are
-- decoded with 'read', which fails if the payload does not parse at the
-- expected type.
runNetworkLocal :: MonadIO m => LocalConfig -> LocTm -> Network m a -> m a
runNetworkLocal cfg self prog = interpFreer handler prog
  where
    handler :: MonadIO m => NetworkSig m a -> m a
    -- Run a local computation directly in the underlying monad.
    handler (Run m)    = m
    -- Write into the receiver's buffer, on the channel reserved for @self@.
    handler (Send a l) = liftIO $ writeChan ((locToBuf cfg ! l) ! self) (show a)
    -- Read from our own buffer, on the channel reserved for sender @l@.
    handler (Recv l)   = liftIO $ read <$> readChan ((locToBuf cfg ! self) ! l)
    -- Send to every location in the config; the list is not filtered, so
    -- this includes @self@.
    handler (BCast a)  = mapM_ handler $ fmap (Send a) (locs cfg)

-- | The local backend runs 'Network' programs with 'runNetworkLocal'.
instance Backend LocalConfig where
  runNetwork = runNetworkLocal