{-# LANGUAGE TemplateHaskell #-}

module CodeGen.Llvm.Runner (run, compileToBinary, compileToLlvmIr) where

import CodeGen.Llvm.LlvmIrGen (genLlvmIrModule, ppLlvmModule)
import CodeGen.Module (compileToModule)
import qualified CodeGen.RunResult as RR
import CodeGen.TimedValue (TimedValue (TimedValue), measureTimedValue)
import Control.Exception (bracket, finally)
import Control.Monad.Except (Except, runExcept)
import Data.FileEmbed (embedFile, makeRelativeToProject)
import Data.String.Conversions (cs)
import Data.Text (Text)
import qualified Data.Text.Encoding as Txt
import qualified Data.Text.IO as Txt
import System.Directory (makeAbsolute, removePathForcibly, withCurrentDirectory)
import System.Exit (ExitCode (..))
import System.IO (IOMode (WriteMode), hClose, withFile)
import System.Posix.Temp (mkdtemp, mkstemps)
import System.Process (callProcess, readProcessWithExitCode)

-- | Compile the given program text to a native binary, execute it, and
-- report the outcome.
--
-- The result is one of:
--
-- * 'RR.CompilationError' when 'compileToBinary' yields a @Left@ message;
-- * 'RR.Success' when the binary exits with 'ExitSuccess';
-- * 'RR.RuntimeError' when the binary exits with a non-zero code.
--
-- Compilation time is always recorded; run time is recorded whenever the
-- binary was executed.
run :: Text -> IO RR.RunResult
run text = do
  TimedValue compResult compTime <- compileToBinary text outputFilePath

  case compResult of
    Right () -> do
      TimedValue runResult runTime <- runCompiledModule
      return $ case runResult of
        Right out ->
          RR.Success
            { RR.stdout = cs out,
              RR.compTime = compTime,
              RR.runTime = runTime
            }
        Left (out, err, code) ->
          RR.RuntimeError
            { RR.stdout = cs out,
              RR.stderr = cs err,
              RR.exitCode = code,
              RR.compTime = compTime,
              RR.runTime = runTime
            }
    Left compErrMsg ->
      return $
        RR.CompilationError
          { RR.compErrMsg = compErrMsg,
            RR.compTime = compTime
          }
  where
    -- Location of the compiled binary, relative to the current directory.
    outputFilePath :: FilePath
    outputFilePath = "./program"

    -- Run the compiled binary with no arguments and empty stdin, timing only
    -- the process execution. @Right@ carries stdout on success; @Left@
    -- carries (stdout, stderr, exit code) on failure.
    --
    -- The binary is removed afterwards via 'finally', so it is cleaned up
    -- even when launching or waiting on the process throws.
    runCompiledModule :: IO (TimedValue (Either (String, String, Int) String))
    runCompiledModule =
      measureTimedValue execute `finally` removePathForcibly outputFilePath
      where
        execute = do
          (exitCode, processOut, processErr) <-
            readProcessWithExitCode outputFilePath [] []
          return $ case exitCode of
            ExitSuccess -> Right processOut
            ExitFailure ec -> Left (processOut, processErr, ec)

-- | Compile program text to a native executable at @outputFilePath@.
--
-- The program is first lowered to LLVM IR; a failure at that stage is
-- returned as @Left@ and no files are touched. Otherwise the IR and the
-- embedded C runtime are written into a fresh temporary build directory and
-- linked with @clang@. The build directory is removed afterwards, including
-- when an exception is raised.
--
-- @outputFilePath@ may be relative (resolved against the directory that is
-- current when this action starts) or absolute.
--
-- The whole action is timed with 'measureTimedValue'.
compileToBinary :: Text -> FilePath -> IO (TimedValue (Either Text ()))
compileToBinary text outputFilePath =
  measureTimedValue $
    sequenceA $
      runExcept $ do
        llvmIrText <- compileToLlvmIr' text
        return $ buildBinary llvmIrText
  where
    -- C runtime source, embedded into this module at compile time.
    runtimeFileText :: Text
    runtimeFileText = Txt.decodeUtf8 $(makeRelativeToProject "lib/CodeGen/Runtime/runtime.c" >>= embedFile)

    -- Write the inputs into a temporary directory and invoke clang there.
    buildBinary :: Text -> IO ()
    buildBinary llvmIrText = do
      -- Resolve the target before changing directory so that both relative
      -- and absolute output paths end up where the caller expects.
      absoluteOutput <- makeAbsolute outputFilePath
      bracket (mkdtemp "build") removePathForcibly $ \buildDir ->
        withCurrentDirectory buildDir $ do
          llvm <- writeTempFile "module" ".ll" llvmIrText
          runtime <- writeTempFile "runtime" ".c" runtimeFileText
          callProcess
            "clang"
            [ "-Wno-override-module",
              "-O3",
              llvm,
              runtime,
              -- Libraries go after the inputs that reference them so the
              -- link does not depend on the linker's ordering policy.
              "-lm",
              "-o",
              absoluteOutput
            ]

    -- Create a uniquely named file in the current directory, write the
    -- contents followed by a newline, and return its path. The handle is
    -- closed even if the write fails.
    writeTempFile :: String -> String -> Text -> IO FilePath
    writeTempFile prefix suffix contents =
      bracket (mkstemps prefix suffix) (hClose . snd) $ \(path, handle) ->
        path <$ Txt.hPutStrLn handle contents

-- | Lower program text to textual LLVM IR and write it to @outputFilePath@.
--
-- When lowering fails the error message is returned as @Left@ and the
-- output file is not opened. On success the IR, followed by a newline, is
-- written to the file opened in 'WriteMode'. The whole action is timed with
-- 'measureTimedValue'.
compileToLlvmIr :: Text -> FilePath -> IO (TimedValue (Either Text ()))
compileToLlvmIr text outputFilePath =
  measureTimedValue (traverse emitIr loweringResult)
  where
    -- Pure lowering step; no I/O happens until this is known to be a Right.
    loweringResult :: Either Text Text
    loweringResult = runExcept (compileToLlvmIr' text)

    -- Write the generated IR to the requested file.
    emitIr :: Text -> IO ()
    emitIr irText =
      withFile outputFilePath WriteMode (\out -> Txt.hPutStrLn out irText)

-- * Internal

-- | Pure pipeline from program text to pretty-printed LLVM IR: build the
-- intermediate module with 'compileToModule', generate the LLVM module from
-- it, and render that as text. Any error from 'compileToModule' is
-- propagated unchanged.
compileToLlvmIr' :: Text -> Except Text Text
compileToLlvmIr' text =
  ppLlvmModule . genLlvmIrModule <$> compileToModule text