never executed always true always false
    1 -- All extensions should be enabled explicitly due to doctest in this module.
    2 {-# LANGUAGE DeriveGeneric #-}
    3 {-# LANGUAGE FlexibleContexts #-}
    4 {-# LANGUAGE FlexibleInstances #-}
    5 {-# LANGUAGE FunctionalDependencies #-}
    6 {-# LANGUAGE ImportQualifiedPost #-}
    7 {-# LANGUAGE LambdaCase #-}
    8 {-# LANGUAGE NamedFieldPuns #-}
    9 
   10 {- |
   11 Module      : NITTA.Model.Problems.Refactor.ConstantFolding
   12 Description : Constant folding optimization
   13 Copyright   : (c) Daniil Prohorov, 2021
   14 License     : BSD3
   15 Maintainer  : aleksandr.penskoi@gmail.com
   16 Stability   : experimental
   17 
   18 == ASCII diagram
   19 
   20 Before compile-time eval optimization
   21 
   22 @
   23     +------------------+
   24     |                  |         +-------------------------+
   25     | Constant 2 "a"   |         |                         |
   26     |                  +-------->+                         |      +--------------+
   27     +------------------+         |                         |      |              |
   28                                  |   Add "a" "b" ["res"]   +----->+    ......    |
   29     +------------------+         |                         |      |              |
   30     |                  +-------->+                         |      +--------------+
   31     | Constant 3 "b"   |         |                         |
   32     |                  |         +-------------------------+
   33     +------------------+
   34 @
   35 
   36 After compile-time eval optimization
   37 
   38 @
   39     +------------------+         +--------------+
   40     |                  |         |              |
   41     | Constant 5 "res" +-------->+    ......    |
   42     |                  |         |              |
   43     +------------------+         +--------------+
   44 @
   45 
   46 == Example from ASCII diagram
   47 
   48 >>> let a = constant 1 ["a"]
   49 >>> let b = constant 2 ["b"]
   50 >>> let res = add "a" "b" ["res"]
   51 >>> loopRes = loop 1 "e" ["res"]
   52 >>> let fs = [a, b, res, loopRes] :: [F String Int]
   53 >>> constantFoldingDecision fs $ head $ constantFoldingOptions fs
   54 [loop(1, e) = res,const(3) = res]
   55 -}
   56 module NITTA.Model.Problems.Refactor.ConstantFolding (
   57     ConstantFolding (..),
   58     ConstantFoldingProblem (..),
   59 ) where
   60 
   61 import Data.Default
   62 import Data.HashMap.Strict qualified as HM
   63 import Data.List qualified as L
   64 import Data.Set qualified as S
   65 import GHC.Generics
   66 import NITTA.Intermediate.Functions
   67 import NITTA.Intermediate.Types
   68 
   69 data ConstantFolding v x = ConstantFolding
   70     { cRefOld :: [F v x]
   71     , cRefNew :: [F v x]
   72     }
   73     deriving (Generic, Show, Eq)
   74 
   75 class ConstantFoldingProblem u v x | u -> v x where
   76     -- | Function takes algorithm in 'DataFlowGraph' and return list of optimizations that can be done
   77     constantFoldingOptions :: u -> [ConstantFolding v x]
   78     constantFoldingOptions _ = []
   79 
   80     -- | Function takes 'ConstantFolding' and modify 'DataFlowGraph'
   81     constantFoldingDecision :: u -> ConstantFolding v x -> u
   82     constantFoldingDecision _ _ = error "not implemented"
   83 
   84 instance (Var v, Val x) => ConstantFoldingProblem [F v x] v x where
   85     constantFoldingOptions fs =
   86         let clusters = selectClusters fs
   87             evaluatedClusters = map evalCluster clusters
   88             zipOfClusters = zip clusters evaluatedClusters
   89             filteredZip = filter (\case ([_], _) -> False; _ -> True) zipOfClusters
   90             options = [ConstantFolding{cRefOld = c, cRefNew = ec} | (c, ec) <- filteredZip, c /= ec]
   91             optionsFiltered = filter isBlankOptions options
   92             isBlankOptions = not . null . constantFoldingDecision fs
   93          in optionsFiltered
   94 
   95     constantFoldingDecision fs ConstantFolding{cRefOld, cRefNew}
   96         | cRefOld == cRefNew = cRefNew
   97         | otherwise = deleteExtraF $ (fs L.\\ cRefOld) <> cRefNew
   98 
   99 selectClusters fs =
  100     let consts = filter isConst fs
  101         isIntersection a b = not . S.null $ S.intersection a b
  102         inputsAreConst f = inputs f `S.isSubsetOf` S.unions (map outputs consts)
  103         getInputConsts f = filter (\c -> outputs c `isIntersection` inputs f) consts
  104         createCluster f
  105             | inputsAreConst f = f : getInputConsts f
  106             | otherwise = [f]
  107      in map createCluster fs
  108 
  109 evalCluster [f] = [f]
  110 evalCluster fs = outputResult
  111     where
  112         (consts, fSingleton) = L.partition isConst fs
  113         f = case fSingleton of
  114             [f'] -> f'
  115             _ -> error "evalCluster: internal error"
  116         cntx = CycleCntx $ HM.fromList $ concatMap (simulate def) consts
  117         outputResult
  118             | null $ outputs f = fs
  119             | otherwise = map (\(v, x) -> constant x [v]) (simulate cntx f) <> consts
  120 
  121 deleteExtraF fs =
  122     L.nub
  123         [ f1
  124         | f1 <- fs
  125         , f2 <- fs
  126         , f1 /= f2
  127         , not $ null (variables f1 `S.intersection` variables f2)
  128         ]