never executed always true always false
1 -- All extensions should be enabled explicitly due to doctest in this module.
2 {-# LANGUAGE DeriveGeneric #-}
3 {-# LANGUAGE FlexibleContexts #-}
4 {-# LANGUAGE FlexibleInstances #-}
5 {-# LANGUAGE FunctionalDependencies #-}
6 {-# LANGUAGE ImportQualifiedPost #-}
7 {-# LANGUAGE LambdaCase #-}
8 {-# LANGUAGE NamedFieldPuns #-}
9
10 {- |
11 Module : NITTA.Model.Problems.Refactor.ConstantFolding
12 Description : Constant folding optimization
13 Copyright : (c) Daniil Prohorov, 2021
14 License : BSD3
15 Maintainer : aleksandr.penskoi@gmail.com
16 Stability : experimental
17
18 == ASCII diagram
19
20 Before compile-time eval optimization
21
22 @
23 +------------------+
24 | | +-------------------------+
25 | Constant 2 "a" | | |
26 | +-------->+ | +--------------+
27 +------------------+ | | | |
28 | Add "a" "b" ["res"] +----->+ ...... |
29 +------------------+ | | | |
30 | +-------->+ | +--------------+
31 | Constant 3 "b" | | |
32 | | +-------------------------+
33 +------------------+
34 @
35
36 After compile-time eval optimization
37
38 @
39 +------------------+ +--------------+
40 | | | |
41 | Constant 5 "res" +-------->+ ...... |
42 | | | |
43 +------------------+ +--------------+
44 @
45
46 == Example from ASCII diagram
47
48 >>> let a = constant 1 ["a"]
49 >>> let b = constant 2 ["b"]
50 >>> let res = add "a" "b" ["res"]
51 >>> loopRes = loop 1 "e" ["res"]
52 >>> let fs = [a, b, res, loopRes] :: [F String Int]
53 >>> constantFoldingDecision fs $ head $ constantFoldingOptions fs
54 [loop(1, e) = res,const(3) = res]
55 -}
56 module NITTA.Model.Problems.Refactor.ConstantFolding (
57 ConstantFolding (..),
58 ConstantFoldingProblem (..),
59 ) where
60
61 import Data.Default
62 import Data.HashMap.Strict qualified as HM
63 import Data.List qualified as L
64 import Data.Set qualified as S
65 import GHC.Generics
66 import NITTA.Intermediate.Functions
67 import NITTA.Intermediate.Types
68
{- | Description of a single constant folding refactoring: which functions
are taken out of the algorithm and which functions are put in their place.
-}
data ConstantFolding v x = ConstantFolding
    { cRefOld :: [F v x]
    -- ^ functions that the refactoring removes from the algorithm
    , cRefNew :: [F v x]
    -- ^ functions that the refactoring adds instead of 'cRefOld'
    }
    deriving (Eq, Generic, Show)
74
{- | Constant folding as a refactoring problem over a model @u@ (for example
a 'DataFlowGraph') whose functions use variables @v@ and values @x@.
-}
class ConstantFoldingProblem u v x | u -> v x where
    -- | Enumerate the constant folding refactorings that can be applied to
    -- the given model. By default no refactorings are offered.
    constantFoldingOptions :: u -> [ConstantFolding v x]
    constantFoldingOptions = const []

    -- | Apply the chosen 'ConstantFolding' to the model and return the
    -- modified model. The default implementation always fails with 'error'.
    constantFoldingDecision :: u -> ConstantFolding v x -> u
    constantFoldingDecision _ _ = error "not implemented"
83
{- | Constant folding over an algorithm given as a plain list of functions.

Options are produced per cluster (see 'selectClusters'): a cluster is offered
as a refactoring only if it contains more than one function, its evaluated
form ('evalCluster') differs from the cluster itself, and applying the
refactoring does not leave an empty algorithm.
-}
instance (Var v, Val x) => ConstantFoldingProblem [F v x] v x where
    constantFoldingOptions fs =
        [ option
        | cluster <- selectClusters fs
        , notSingleton cluster
        , let folded = evalCluster cluster
        , cluster /= folded
        , let option = ConstantFolding{cRefOld = cluster, cRefNew = folded}
        , leavesNonEmptyAlgorithm option
        ]
      where
        -- A cluster of exactly one function has nothing to fold.
        notSingleton [_] = False
        notSingleton _ = True
        -- Drop refactorings after which no function would remain.
        leavesNonEmptyAlgorithm = not . null . constantFoldingDecision fs

    constantFoldingDecision fs ConstantFolding{cRefOld, cRefNew}
        -- Replacing a cluster by itself is a no-op: keep the whole algorithm
        -- instead of shrinking it down to just the cluster.
        | cRefOld == cRefNew = fs
        -- Swap the old cluster for the new one, then drop functions that no
        -- longer share a variable with any other function.
        | otherwise = deleteExtraF $ (fs L.\\ cRefOld) <> cRefNew
98
{- | Build one candidate cluster for every function of the algorithm.

If every input of a function is an output of some function accepted by
'isConst', the cluster is that function followed by the constants that feed
it; otherwise the cluster is the function alone. The result has the same
length and order as the input list.
-}
selectClusters fs = map clusterFor fs
  where
    constFs = filter isConst fs
    -- Loop-invariant: computed once and shared by every 'clusterFor' call
    -- instead of being rebuilt for each function.
    constOutputs = S.unions (map outputs constFs)
    clusterFor f
        | inputs f `S.isSubsetOf` constOutputs = f : feedingConsts f
        | otherwise = [f]
    -- Constants with at least one output among the inputs of @f@.
    feedingConsts f = filter (not . S.disjoint (inputs f) . outputs) constFs
108
{- | Evaluate a cluster produced by 'selectClusters'.

A single-function cluster is returned untouched. Otherwise the cluster must
consist of constants plus exactly one other function (anything else is an
internal error). If that function has no outputs the cluster is returned
as is; if it has outputs, each (variable, value) pair obtained by simulating
the function on the constants' values becomes a new constant, and the
original constants are appended after them.
-}
evalCluster cluster = case cluster of
    [_] -> cluster
    _
        | null (outputs target) -> cluster
        | otherwise -> [constant x [v] | (v, x) <- simulate constCntx target] <> constFs
  where
    (constFs, others) = L.partition isConst cluster
    -- The one non-constant function of the cluster.
    target = case others of
        [single] -> single
        _ -> error "evalCluster: internal error"
    -- Context filled with the values produced by the cluster's constants.
    constCntx = CycleCntx . HM.fromList $ concatMap (simulate def) constFs
120
{- | Keep only the functions that share at least one variable with some
other (non-equal) function of the list, without duplicates and in their
original order.
-}
deleteExtraF fs = L.nub (filter isConnected fs)
  where
    -- True when some different function of @fs@ uses one of @f@'s variables.
    isConnected f = any (sharesVariableWith f) fs
    sharesVariableWith f other =
        f /= other && not (null (variables f `S.intersection` variables other))