never executed always true always false
1 -- All extensions should be enabled explicitly due to doctest in this module.
2 {-# LANGUAGE ConstraintKinds #-}
3 {-# LANGUAGE DeriveGeneric #-}
4 {-# LANGUAGE FlexibleContexts #-}
5 {-# LANGUAGE FlexibleInstances #-}
6 {-# LANGUAGE FunctionalDependencies #-}
7 {-# LANGUAGE GADTs #-}
8 {-# LANGUAGE ImportQualifiedPost #-}
9 {-# LANGUAGE LambdaCase #-}
10 {-# LANGUAGE NamedFieldPuns #-}
11 {-# LANGUAGE TypeFamilies #-}
12 {-# LANGUAGE UndecidableInstances #-}
13
14 {- |
15 The optimization consists of two parts:
16 1) Replacing logic functions with lookup tables
17 2) Searching and merging lookup tables, if possible
18
19 >>> let a = constant 1 ["a"]
20 >>> let b = constant 2 ["b"]
21 >>> let c = constant 3 ["c"]
22 >>> let f1 = logicAnd "a" "b" ["f1"]
23 >>> let f2 = logicOr "f1" "c" ["f2"]
24 >>> let loopRes = loop 1 "e" ["f2"]
25 >>> let fs = [a, b, c, f1, f2, loopRes] :: [F String Int]
26 >>> optimizeLogicalUnitDecision fs $ head $ optimizeLogicalUnitOptions fs
27 [const(1) = a,const(2) = b,const(3) = c,loop(1, e) = f2,TruthTable fromList [([False,False,False],False),([False,False,True],True),([False,True,False],False),([False,True,True],True),([True,False,False],False),([True,False,True],True),([True,True,False],True),([True,True,True],True)] [a,b,c] = f2]
28 -}
29 module NITTA.Model.Problems.Refactor.OptimizeLogicalUnit (
30 OptimizeLogicalUnit (..),
31 OptimizeLogicalUnitProblem (..),
32 )
33 where
34
35 import Control.Monad (replicateM)
36 import Data.Foldable (foldl')
37 import Data.List qualified as L
38 import Data.Map qualified as M
39 import Data.Maybe
40 import Data.Set qualified as S
41 import GHC.Generics
42 import NITTA.Intermediate.Functions
43 import NITTA.Intermediate.Types
44
{- | A single refactoring option: replace the functions in 'rOld' with the
functions in 'rNew'.
-}
data OptimizeLogicalUnit v x = OptimizeLogicalUnit
    { rOld :: [F v x]
    -- ^ Functions that the decision removes from the function list.
    , rNew :: [F v x]
    -- ^ Functions that the decision appends in their place.
    }
    deriving (Eq, Generic, Show)
50
{- | Refactoring problem: find logic functions that can be replaced by lookup
tables and apply such a replacement to a container @u@.
-}
class OptimizeLogicalUnitProblem u v x | u -> v x where
    -- | List the available replacements. The default offers none.
    optimizeLogicalUnitOptions :: u -> [OptimizeLogicalUnit v x]
    optimizeLogicalUnitOptions = const []

    {- | Apply the given 'OptimizeLogicalUnit' to the container and return the
    modified container. The default is a placeholder that raises an error.
    -}
    optimizeLogicalUnitDecision :: u -> OptimizeLogicalUnit v x -> u
    optimizeLogicalUnitDecision _ _ = error "not implemented"
58
-- | Logical unit optimization over a plain list of functions.
instance (Var v, Val x) => OptimizeLogicalUnitProblem [F v x] v x where
    -- At most one option is produced. It is offered only when there are
    -- supported logic functions, the rewrite yields something, and the
    -- rewritten set differs from the original set of logic functions.
    optimizeLogicalUnitOptions fs
        | null candidates = []
        | not (isOptimizationNeeded candidates) = []
        | null replacement = []
        | S.fromList candidates == S.fromList replacement = []
        | otherwise = [OptimizeLogicalUnit{rOld = candidates, rNew = replacement}]
        where
            candidates = filter isSupportedByLogicalUnit fs
            replacement = optimizeCluster candidates fs

    -- Remove the old functions, append the new ones, then prune functions
    -- that are no longer connected to anything.
    optimizeLogicalUnitDecision fs OptimizeLogicalUnit{rOld, rNew} =
        deleteExtraLogicalUnits . (<> rNew) $ fs L.\\ rOld
76
{- | Keep only the functions that share at least one variable with some other,
different function of the list; duplicates are removed and the first
occurrence order is preserved.
-}
deleteExtraLogicalUnits fs = L.nub (filter hasNeighbour fs)
    where
        hasNeighbour f = any (isConnectedTo f) fs
        isConnectedTo f other =
            f /= other
                && not (S.null (variables f `S.intersection` variables other))
85
{- | Optimization is worthwhile when there is more than one lookup table to
merge, or when at least one plain logic function is still present.
-}
isOptimizationNeeded fs
    | countLogicalUnits fs > 1 = True
    | otherwise = any isSupportedByLogicalUnit fs
89
-- | 'True' when the function casts to a 'TruthTable'.
isLogicalUnit f
    | Just TruthTable{} <- castF f = True
    | otherwise = False
93
-- | Number of functions in the list for which 'isLogicalUnit' holds.
countLogicalUnits fns = length [() | f <- fns, isLogicalUnit f]
95
{- | 'True' for the logic functions that can be turned into a lookup table:
'LogicAnd', 'LogicOr' and 'LogicNot'.
-}
isSupportedByLogicalUnit f = case castF f of
    Just LogicAnd{} -> True
    Just LogicOr{} -> True
    Just LogicNot{} -> True
    _ -> False
101
{- | Rewrite a list of logic functions into lookup tables.

The result is the concatenation of:

1. one table per mergeable cluster found by 'findMergeClusters';
2. one table per supported function whose output set does not have exactly
   one element (these are converted individually);
3. every function that belongs to neither group, unchanged.

The second argument is ignored.
-}
optimizeCluster allFunctions _ = merged ++ standalone ++ untouched
    where
        clusters = findMergeClusters allFunctions
        merged = mapMaybe mergeCluster clusters

        singleFunctions =
            [ f
            | f <- allFunctions
            , isSupportedByLogicalUnit f
            , S.size (outputs f) /= 1
            ]
        standalone = mapMaybe convertToLOGICALUNIT singleFunctions

        untouched = allFunctions L.\\ (concat clusters ++ singleFunctions)

        -- Only chains in which every link has a single output are merged.
        mergeCluster cluster =
            if isSingleOutputChain cluster
                then mergeLogicCluster M.empty cluster
                else Nothing

        -- Turn one logic gate into an equivalent lookup table.
        convertToLOGICALUNIT f = case castF f of
            Just (LogicAnd (I a) (I b) (O out)) ->
                buildCombinedLOGICALUNIT [a, b] out evalAnd
            Just (LogicOr (I a) (I b) (O out)) ->
                buildCombinedLOGICALUNIT [a, b] out evalOr
            Just (LogicNot (I a) (O out)) ->
                buildCombinedLOGICALUNIT [a] out evalNot
            _ -> Nothing

        -- Gate evaluators over a positional list of input values.
        evalAnd [x, y] = x && y
        evalAnd _ = error "Unexpected pattern"

        evalOr [x, y] = x || y
        evalOr _ = error "Unexpected pattern"

        evalNot [x] = not x
        evalNot _ = error "Unexpected pattern"
142
{- | Collapse a chain of logic functions into a single lookup table over the
chain's external inputs. The first argument is ignored.
-}
mergeLogicCluster _ fs =
    buildCombinedLOGICALUNIT inputVars finalOutput (buildCombinedLogic fs inputVars)
    where
        (inputVars, finalOutput) = analyzeClusterIO fs
147
{- | Check that a list of functions forms a chain of single-output links:
every function has exactly one output variable, and each function shares
exactly one variable between its outputs and the inputs of the next function.

Empty and one-element lists have no adjacent pairs, so only the first
condition applies to them.
-}
isSingleOutputChain fs =
    all hasSingleOutput fs
        -- Adjacent pairs are walked with 'zipWith' rather than by index, so
        -- the check is a single pass and uses no partial list indexing.
        && and (zipWith linkedByOneVariable fs (drop 1 fs))
    where
        hasSingleOutput f = S.size (outputs f) == 1
        linkedByOneVariable producer consumer =
            S.size (outputs producer `S.intersection` inputs consumer) == 1
151
{- | Compute the interface of a cluster: the list of variables that are
consumed but not produced inside the cluster, paired with the output set of
the last function. The list must be non-empty because 'last' is used.
-}
analyzeClusterIO fs =
    (S.toList (consumed `S.difference` produced), outputs (last fs))
    where
        consumed = S.unions (map inputs fs)
        produced = S.unions (map outputs fs)
158
{- | Build the evaluation function of a whole cluster.

Given one boolean per entry of @inputVars@ (positionally), the gates are
applied in list order and the value of the first output variable of the last
function is returned.
-}
buildCombinedLogic fs inputVars comb = finalEnv M.! resultVar
    where
        initialEnv = M.fromList (zip inputVars comb)
        finalEnv = foldl' (flip applyLogicGate) initialEnv fs
        resultVar = S.elemAt 0 (outputs (last fs))
165
{- | Evaluate one logic gate against the current variable assignment and
record its result under the gate's output variable.

Functions that are not 'LogicAnd', 'LogicOr' or 'LogicNot' leave the
assignment unchanged. A gate whose output set does not contain exactly one
variable raises an error.
-}
applyLogicGate f varMap = case castF f of
    Just (LogicAnd (I a) (I b) (O out)) ->
        store "LogicAnd must have exactly one output: 1" out (valueOf a && valueOf b)
    Just (LogicOr (I a) (I b) (O out)) ->
        store "LogicOr must have exactly one output: 2" out (valueOf a || valueOf b)
    Just (LogicNot (I a) (O out)) ->
        store "LogicNot must have exactly one output: 3" out (not (valueOf a))
    _ -> varMap
    where
        valueOf var = varMap M.! var

        -- Write the gate result, insisting on a single output variable.
        store message out result = case S.toList out of
            [outVar] -> M.insert outVar result varMap
            _ -> error message
180
{- | Build a 'TruthTable' function by evaluating @evalFn@ on every boolean
assignment of the input variables.

The table has @2 ^ length inputVars@ rows. The result is always 'Just'.
-}
buildCombinedLOGICALUNIT :: (Var v, Val x) => [v] -> S.Set v -> ([Bool] -> Bool) -> Maybe (F v x)
buildCombinedLOGICALUNIT inputVars outputSet evalFn =
    Just . packF $ TruthTable truthTable (I <$> inputVars) (O outputSet)
    where
        assignments = replicateM (length inputVars) [False, True]
        truthTable = M.fromList (zip assignments (map evalFn assignments))
188
{- | Layered topological sort of a graph given as @(node, dependencies)@ pairs.

Nodes with an empty dependency list are emitted first; they are then removed
from the dependency lists of the remaining nodes and the process repeats.
If at some step no node is free (for example on a cycle), the remaining
nodes are silently dropped from the result.
-}
topSort :: Eq a => [(a, [a])] -> [a]
topSort graph
    | null graph = []
    | null free = []
    | otherwise = freeNodes ++ topSort pruned
    where
        (free, blocked) = L.partition (\(_, deps) -> null deps) graph
        freeNodes = [node | (node, _) <- free]
        pruned =
            [ (node, [d | d <- deps, d `notElem` freeNodes])
            | (node, deps) <- blocked
            ]
201
{- | Split the functions into chains that are candidates for merging.

Each function is paired with the functions that consume its outputs, the
result is ordered with 'topSort' and reversed, and the ordered list is then
cut greedily into consecutive runs: a function joins the current run only if
it consumes an output of the run's last function and the extended run still
satisfies 'isSingleOutputChain'.
-}
findMergeClusters :: Var v => [F v x] -> [[F v x]]
findMergeClusters fs = groupChains . reverse . topSort $ dependencyGraph
    where
        -- Every function paired with the functions that read its outputs.
        dependencyGraph = [(producer, filter (feeds producer) fs) | producer <- fs]

        feeds producer consumer =
            not . S.null $ outputs producer `S.intersection` inputs consumer

        groupChains [] = []
        groupChains (start : others) = chain : groupChains leftover
            where
                (chain, leftover) = extend [start] others

        -- Grow the chain while the next function continues it; the first
        -- function that does not fit ends the chain and is left in place.
        extend chain (next : later)
            | feeds (last chain) next && isSingleOutputChain (chain ++ [next]) =
                extend (chain ++ [next]) later
        extend chain pending = (chain, pending)