module Vectorise.Monad (
  module Vectorise.Monad.Base,
  module Vectorise.Monad.Naming,
  module Vectorise.Monad.Local,
  module Vectorise.Monad.Global,
  module Vectorise.Monad.InstEnv,
  initV,

  -- * Builtins
  liftBuiltinDs,
  builtin,
  builtins,

  -- * Variables
  lookupVar,
  lookupVar_maybe,
  addGlobalParallelVar,
  addGlobalParallelTyCon,
) where

import Vectorise.Monad.Base
import Vectorise.Monad.Naming
import Vectorise.Monad.Local
import Vectorise.Monad.Global
import Vectorise.Monad.InstEnv
import Vectorise.Builtins
import Vectorise.Env

import CoreSyn
import TcRnMonad
import DsMonad
import HscTypes hiding ( MonadThings(..) )
import DynFlags
import MonadUtils (liftIO)
import InstEnv
import Class
import TyCon
import NameSet
import VarSet
import VarEnv
import Var
import Id
import Name
import ErrUtils
import Outputable
import Module

import Control.Monad (join)

-- |Run a vectorisation computation.
--
-- Sets up the builtin tables and the initial global environment, then runs
-- @thing_inside@ inside the desugarer monad (via 'initDsWithModGuts').
--
-- * @hsc_env@      – used for the 'DynFlags' and the external package state.
-- * @guts@         – the module being vectorised (bindings, instances, vect decls).
-- * @info@         – the incoming 'VectInfo'.
-- * @thing_inside@ – the vectorisation computation to run.
--
-- Returns the updated 'VectInfo' paired with the computation's result.  The result is
-- 'Nothing' if either the desugarer itself fails or the vectorisation computation
-- yields 'No'; in the latter case a warning containing the failure reason is printed.
-- When 'Opt_D_dump_vt_trace' is set, the incoming and outgoing 'VectInfo' are dumped.
--
initV :: HscEnv
      -> ModGuts
      -> VectInfo
      -> VM a
      -> IO (Maybe (VectInfo, a))
initV hsc_env guts info thing_inside
  = do { dumpIfVtTrace "Incoming VectInfo" (ppr info)

       ; (_, res) <- initDsWithModGuts hsc_env guts go
         -- collapse desugarer failure (outer 'Nothing') and vectorisation failure
         -- (inner 'Nothing') into a single result, computed once
       ; let result = join res
       ; case result of
           Nothing
             -> dumpIfVtTrace "Vectorisation FAILED!" empty
           Just (info', _)
             -> dumpIfVtTrace "Outgoing VectInfo" (ppr info')

       ; return result
       }
  where
    dflags = hsc_dflags hsc_env

    dumpIfVtTrace = dumpIfSet_dyn dflags Opt_D_dump_vt_trace

    bindsToIds (NonRec v _)   = [v]
    bindsToIds (Rec    binds) = map fst binds

    -- all top-level binders of the module; passed to 'modVectInfo'
    ids = concatMap bindsToIds (mg_binds guts)

    go
      = do {   -- set up tables of builtin entities
               -- ('bi' rather than 'builtins' to avoid shadowing the exported 'builtins')
           ; bi              <- initBuiltins
           ; builtin_vars    <- initBuiltinVars bi

               -- set up class and type family environments
           ; eps <- liftIO $ hscEPS hsc_env
           ; let famInstEnvs = (eps_fam_inst_env eps, mg_fam_inst_env guts)
                 instEnvs    = InstEnvs (eps_inst_env     eps)
                                        (mg_inst_env     guts)
                                        (mkModuleSet (dep_orphs (mg_deps guts)))
                 builtin_pas = initClassDicts instEnvs (paClass bi)  -- grab all 'PA' and..
                 builtin_prs = initClassDicts instEnvs (prClass bi)  -- ..'PR' class instances

               -- construct the initial global environment
           ; let genv = extendImportedVarsEnv builtin_vars
                        . setPAFunsEnv        builtin_pas
                        . setPRFunsEnv        builtin_prs
                        $ initGlobalEnv (gopt Opt_VectorisationAvoidance dflags)
                                        info (mg_vect_decls guts) instEnvs famInstEnvs

               -- perform vectorisation
           ; r <- runVM thing_inside bi genv emptyLocalEnv
           ; case r of
               Yes genv' _ x -> return $ Just (new_info genv', x)
               No reason     -> do { unqual <- mkPrintUnqualifiedDs
                                   ; liftIO $
                                       printOutputForUser dflags unqual $
                                         mkDumpDoc "Warning: vectorisation failure:" reason
                                   ; return Nothing
                                   }
           }

    -- the outgoing 'VectInfo', built from the final global environment
    new_info genv = modVectInfo genv ids (mg_tcs guts) (mg_vect_decls guts) info

    -- For a given DPH class, produce a mapping from type constructor (in head position) to the
    -- instance dfun for that type constructor and class.  (DPH class instances cannot overlap in
    -- head constructors.)
    --
    -- Panics on any instance whose rough match is not exactly one known head type constructor.
    --
    initClassDicts :: InstEnvs -> Class -> [(Name, Var)]
    initClassDicts insts cls = map find $ classInstances insts cls
      where
        find i | [Just tc] <- instanceRoughTcs i = (tc, instanceDFunId i)
               | otherwise                       = pprPanic invalidInstance (ppr i)

    invalidInstance = "Invalid DPH instance (overlapping in head constructor)"

-- Builtins -------------------------------------------------------------------

-- |Lift a desugaring computation that is parameterised by the `Builtins` into the
-- vectorisation monad.  The global and local environments are passed through unchanged,
-- and the computation always succeeds with the desugarer's result.
--
liftBuiltinDs :: (Builtins -> DsM a) -> VM a
liftBuiltinDs dsAction = VM $ \bis globalEnv localEnv ->
  Yes globalEnv localEnv <$> dsAction bis

-- |Project something from the set of builtins.  This is a pure projection, lifted as
-- a trivial desugaring computation; the environments are left unchanged.
--
builtin :: (Builtins -> a) -> VM a
builtin project = liftBuiltinDs (return . project)

-- |Lift a function that takes the `Builtins` as its second argument into the
-- vectorisation monad by supplying the builtins, yielding the partially applied function.
--
builtins :: (a -> Builtins -> b) -> VM (a -> b)
builtins g = builtin (flip g)


-- Var ------------------------------------------------------------------------

-- |Lookup the vectorised, and if local, also the lifted version of a variable.
--
-- * If it's in the global environment we get the vectorised version.
-- * If it's in the local environment we get both the vectorised and lifted version.
--
-- If the variable is in neither environment, the failure is reported via 'dumpVar'.
--
lookupVar :: Var -> VM (Scope Var (Var, Var))
lookupVar v = lookupVar_maybe v >>= maybe notVectorised return
  where
    -- 'dumpVar' has a fully polymorphic result type, so it does not return normally
    notVectorised = do { dflags <- getDynFlags
                       ; dumpVar dflags v
                       }

-- |Like 'lookupVar', but returns 'Nothing' instead of failing.  The local environment
-- is consulted first; only if the variable is absent there is the global environment tried.
--
lookupVar_maybe :: Var -> VM (Maybe (Scope Var (Var, Var)))
lookupVar_maybe v
  = readLEnv (\lenv -> lookupVarEnv (local_vars lenv) v)
      >>= maybe inGlobal (return . Just . Local)
  where
    inGlobal = fmap Global <$> readGEnv (\genv -> lookupVarEnv (global_vars genv) v)

-- |Report (via 'cantVectorise') that a variable has no vectorised counterpart.  The
-- message distinguishes class-operation selector ids from ordinary variables.
--
dumpVar :: DynFlags -> Var -> a
dumpVar dflags var = cantVectorise dflags complaint (ppr var)
  where
    complaint = case isClassOpId_maybe var of
                  Just _  -> "ClassOpId not vectorised:"
                  Nothing -> "Variable not vectorised:"


-- Global parallel entities ----------------------------------------------------

-- |Mark the given variable as parallel — i.e., executing the associated code might involve
-- parallel array computations.  Emits a vectoriser trace message, then adds the variable
-- to 'global_parallel_vars' in the global environment.
--
addGlobalParallelVar :: Var -> VM ()
addGlobalParallelVar var
  = traceVt "addGlobalParallelVar" (ppr var) >> updGEnv markParallel
  where
    markParallel genv
      = genv { global_parallel_vars = global_parallel_vars genv `extendDVarSet` var }

-- |Mark the given type constructor as parallel — i.e., its values might embed parallel arrays.
-- Emits a vectoriser trace message, then records the tycon's 'Name' in
-- 'global_parallel_tycons' in the global environment.
--
addGlobalParallelTyCon :: TyCon -> VM ()
addGlobalParallelTyCon tycon
  = traceVt "addGlobalParallelTyCon" (ppr tycon) >> updGEnv markParallel
  where
    markParallel genv
      = genv { global_parallel_tycons =
                 global_parallel_tycons genv `extendNameSet` tyConName tycon }