never executed always true always false
1 -- All extensions should be enabled explicitly due to doctest in this module.
2 {-# LANGUAGE ConstraintKinds #-}
3 {-# LANGUAGE DeriveGeneric #-}
4 {-# LANGUAGE FlexibleContexts #-}
5 {-# LANGUAGE FlexibleInstances #-}
6 {-# LANGUAGE FunctionalDependencies #-}
7 {-# LANGUAGE GADTs #-}
8 {-# LANGUAGE ImportQualifiedPost #-}
9 {-# LANGUAGE LambdaCase #-}
10 {-# LANGUAGE NamedFieldPuns #-}
11 {-# LANGUAGE TypeFamilies #-}
12 {-# LANGUAGE UndecidableInstances #-}
13
14 {- |
15 Module : NITTA.Model.Problems.Refactor.OptimizeAccum
16 Description : Optimize an algorithm for Accum processor unit
17 Copyright : (c) Daniil Prohorov, 2021
18 License : BSD3
19 Maintainer : aleksandr.penskoi@gmail.com
20 Stability : experimental
21 -}
22 module NITTA.Model.Problems.Refactor.OptimizeAccum (
23 OptimizeAccum (..),
24 OptimizeAccumProblem (..),
25 )
26 where
27
28 import Data.List qualified as L
29 import Data.Map qualified as M
30 import Data.Maybe
31 import Data.Set qualified as S
32 import GHC.Generics
33 import NITTA.Intermediate.Functions
34 import NITTA.Intermediate.Types
35
{- | A single Accum refactoring: the cluster of functions to drop from the
algorithm ('refOld') and the functions that replace it ('refNew').

OptimizeAccum example:

> OptimizeAccum [+a +b => tmp1; +tmp1 +c => res] [+a +b +c => res]

before:

> [+a +b => tmp1; +tmp1 +c => res]

after:

> [+a +b +c => res]

== Doctest optimize example

>>> let a = constant 1 ["a"]
>>> let b = constant 2 ["b"]
>>> let c = constant 3 ["c"]
>>> let tmp1 = add "a" "b" ["tmp1"]
>>> let res = add "tmp1" "c" ["res"]
>>> let loopRes = loop 1 "e" ["res"]
>>> let fs = [a, b, c, tmp1, res, loopRes] :: [F String Int]
>>> optimizeAccumDecision fs $ head $ optimizeAccumOptions fs
[Acc(+a +b +c = res),const(1) = a,const(2) = b,const(3) = c,loop(1, e) = res]
-}
data OptimizeAccum v x = OptimizeAccum
    { refOld :: [F v x]
    -- ^ Functions removed from the algorithm by the refactoring.
    , refNew :: [F v x]
    -- ^ Functions added to the algorithm in place of 'refOld'.
    }
    deriving (Generic, Show, Eq)
65
-- | Refactoring problem: find and apply 'OptimizeAccum' rewrites of an
-- algorithm representation @u@ whose variables are @v@ and values are @x@.
class OptimizeAccumProblem u v x | u -> v x where
    -- | List the 'OptimizeAccum' refactorings that can be applied to @u@.
    -- The default implementation offers none.
    optimizeAccumOptions :: u -> [OptimizeAccum v x]
    optimizeAccumOptions _ = []

    -- | Apply one 'OptimizeAccum' refactoring to @u@. The default
    -- implementation is a placeholder that raises an error.
    optimizeAccumDecision :: u -> OptimizeAccum v x -> u
    optimizeAccumDecision _ _ = error "not implemented"
74
instance (Var v, Val x) => OptimizeAccumProblem [F v x] v x where
    -- Every cluster of Accum-compatible functions whose merged form differs
    -- (compared as a set) from the cluster itself; duplicates are dropped.
    optimizeAccumOptions fs =
        L.nub $
            [ OptimizeAccum{refOld = cluster, refNew = merged}
            | cluster <- selectClusters (filter isSupportByAccum fs)
            , let merged = optimizeCluster cluster
            , S.fromList cluster /= S.fromList merged
            ]

    -- Drop the old cluster from the algorithm and prepend its replacement.
    optimizeAccumDecision fs OptimizeAccum{refOld = old, refNew = new} =
        new <> (fs L.\\ old)
87
{- | Candidate clusters: pairs @[f, g]@ of distinct functions where one reads
a variable written by the other. A pair is reported only once regardless of
its order (the first occurrence is kept).
-}
selectClusters fs = L.nubBy sameMembers $ concatMap pairsWith fs
    where
        pairsWith f = [[f, g] | g <- fs, g /= f, linked f g]
        sameMembers a b = S.fromList a == S.fromList b
        -- Data flows between the two functions in at least one direction.
        linked x y = feeds x y || feeds y x
        feeds producer consumer = not $ S.disjoint (outputs producer) (inputs consumer)
100
-- | Whether the function is one of the kinds handled here: 'Add', 'Sub',
-- 'Neg' or an already-built 'Acc'.
isSupportByAccum f = isAdd || isSub || isNeg || isAcc
    where
        isAdd = case castF f of Just Add{} -> True; _ -> False
        isSub = case castF f of Just Sub{} -> True; _ -> False
        isNeg = case castF f of Just Neg{} -> True; _ -> False
        isAcc = case castF f of Just Acc{} -> True; _ -> False
107
{- | Index functions by the variables they read: every input variable is
mapped to all functions in @fs@ that consume it, in the order they appear
in @fs@.

The per-variable lists are concatenated. A plain 'M.unions' is left-biased
and would keep only the first consumer of each variable.
-}
containerMapCreate fs =
    M.fromListWith
        -- fromListWith calls (combine new old); flipping (++) keeps fs order.
        (flip (++))
        [ (v, [f])
        | f <- fs
        , v <- S.toList $ inputs f
        ]
121
{- | Merge producer/consumer pairs inside a cluster.

For each function @f@ and each of its output variables, every function of
the cluster that reads that variable is handed to @refactorFunction@ together
with @f@. The result is either one merged accumulator or the untouched pair.
Functions whose outputs are not read inside the cluster contribute nothing.
-}
optimizeCluster fs = concatMap refactored fs
    where
        consumersOf = containerMapCreate fs

        refactored f =
            [ g
            | o <- S.toList $ outputs f
            , consumer <- M.findWithDefault [] o consumersOf
            , g <- refactorFunction f consumer
            ]

        -- Merge producer f' into consumer f. Add/Sub/Neg are first converted
        -- to accumulators by 'fromAddSub'. Returns [merged] on success and the
        -- original pair [f, f'] otherwise. The merge step itself does not
        -- recurse, so it cannot loop on inputs that are already accumulators.
        refactorFunction f' f
            | Just producer <- fromAddSub f'
            , Just consumer <- fromAddSub f
            , Just merged <- mergeAccs producer consumer =
                [merged]
            | otherwise = [f, f']

        -- Inline accumulator 'producer' into accumulator 'consumer'. This is
        -- done only when the producer has exactly one output variable and the
        -- consumer pushes at least one of the producer's outputs.
        mergeAccs producer consumer
            | Just (Acc producerOps) <- castF producer
            , Just (Acc consumerOps) <- castF consumer
            , S.size producerOuts == 1
            , any readsProducer consumerOps =
                Just $ packF $ Acc $ concatMap (inlineInto producerOps) consumerOps
            | otherwise = Nothing
            where
                producerOuts = outputs producer

                readsProducer = \case
                    Push _ (I v) -> v `S.member` producerOuts
                    _ -> False

                -- A consumer push of a producer output is replaced by the
                -- producer's operations with the push sign applied. Other
                -- pushes and all pulls are kept as they are.
                inlineInto _ pull@(Pull _) = [pull]
                inlineInto producerOps (Push sign i@(I v))
                    | v `S.member` producerOuts = mapMaybe (substitute sign v) producerOps
                    | otherwise = [Push sign i]

                -- Producer pushes get the outer sign composed in. Producer
                -- pulls lose the inlined variable and are dropped when it was
                -- their only one.
                substitute outer _ (Push inner i) = Just $ Push (composeSign outer inner) i
                substitute _ v pull@(Pull _) = deleteFromPull v pull

                composeSign Plus s = s
                composeSign Minus Plus = Minus
                composeSign Minus Minus = Plus
168
-- | Remove variable @v@ from a 'Pull'. The result is 'Nothing' when no
-- variables remain. Applying it to a 'Push' is an error.
deleteFromPull v = \case
    Pull (O s) ->
        let rest = S.delete v s
         in if S.null rest then Nothing else Just $ Pull $ O rest
    Push _ _ -> error "delete only Pull"
175
{- | View an 'Add', 'Sub' or 'Neg' as an equivalent accumulator. Each output
variable becomes its own single-variable pull. An existing 'Acc' is returned
as is, and any other function yields 'Nothing'.
-}
fromAddSub f
    | Just (Add x y (O out)) <- castF f = Just $ toAcc [Push Plus x, Push Plus y] out
    | Just (Sub x y (O out)) <- castF f = Just $ toAcc [Push Plus x, Push Minus y] out
    | Just (Neg x (O out)) <- castF f = Just $ toAcc [Push Minus x] out
    | Just Acc{} <- castF f = Just f
    | otherwise = Nothing
    where
        -- Pushes first, then one pull per output variable.
        toAcc pushes out = acc $ pushes ++ map (Pull . O . S.singleton) (S.toList out)