never executed always true always false
1 -- All extensions should be enabled explicitly due to doctest in this module.
2 {-# LANGUAGE ConstraintKinds #-}
3 {-# LANGUAGE DeriveGeneric #-}
4 {-# LANGUAGE FlexibleContexts #-}
5 {-# LANGUAGE FlexibleInstances #-}
6 {-# LANGUAGE FunctionalDependencies #-}
7 {-# LANGUAGE GADTs #-}
8 {-# LANGUAGE ImportQualifiedPost #-}
9 {-# LANGUAGE LambdaCase #-}
10 {-# LANGUAGE NamedFieldPuns #-}
11 {-# LANGUAGE TypeFamilies #-}
12 {-# LANGUAGE UndecidableInstances #-}
13
14 {- |
15 Module : NITTA.Model.Problems.Refactor.OptimizeAccum
16 Description : Optimize an algorithm for Accum processor unit
17 Copyright : (c) Daniil Prohorov, 2021
18 License : BSD3
19 Maintainer : aleksandr.penskoi@gmail.com
20 Stability : experimental
21 -}
22 module NITTA.Model.Problems.Refactor.OptimizeAccum (
23 OptimizeAccum (..),
24 OptimizeAccumProblem (..),
25 ) where
26
27 import Data.List qualified as L
28 import Data.Map qualified as M
29 import Data.Maybe
30 import Data.Set qualified as S
31 import GHC.Generics
32 import NITTA.Intermediate.Functions
33 import NITTA.Intermediate.Types
34
35 {- | OptimizeAccum example:
36
37 > OptimizeAccum [+a +b => tmp1; +tmp1 +c => res] [+a +b +c => res]
38
39 before:
40
41 > [+a +b => tmp1; +tmp1 +c => res]
42
43 after:
44
45 > [+a +b +c => res]
46
47 == Doctest optimize example
48
49 >>> let a = constant 1 ["a"]
50 >>> let b = constant 2 ["b"]
51 >>> let c = constant 3 ["c"]
52 >>> let tmp1 = add "a" "b" ["tmp1"]
53 >>> let res = add "tmp1" "c" ["res"]
54 >>> let loopRes = loop 1 "e" ["res"]
55 >>> let fs = [a, b, c, tmp1, res, loopRes] :: [F String Int]
56 >>> optimizeAccumDecision fs $ head $ optimizeAccumOptions fs
57 [Acc(+a +b +c = res),const(1) = a,const(2) = b,const(3) = c,loop(1, e) = res]
58 -}
data OptimizeAccum v x = OptimizeAccum
    { refOld :: [F v x]
    -- ^ functions that are removed from the algorithm by the refactoring
    , refNew :: [F v x]
    -- ^ functions that are added to the algorithm in place of 'refOld'
    }
    deriving (Eq, Show, Generic)
64
-- | Problem of replacing chains of accumulator-compatible functions with a
-- single 'Acc' function.
class OptimizeAccumProblem u v x | u -> v x where
    -- | All 'OptimizeAccum' refactorings that can be applied to the algorithm
    -- @u@. By default there are none.
    optimizeAccumOptions :: u -> [OptimizeAccum v x]
    optimizeAccumOptions = const []

    -- | Apply an 'OptimizeAccum' refactoring to the algorithm @u@. The
    -- default method is not implemented and calls 'error'.
    optimizeAccumDecision :: u -> OptimizeAccum v x -> u
    optimizeAccumDecision = \_ _ -> error "not implemented"
73
-- Options: every connected pair of accumulator-compatible functions whose
-- optimization really changes the set of functions (duplicates removed).
-- Decision: drop the old functions and add the optimized ones.
instance (Var v, Val x) => OptimizeAccumProblem [F v x] v x where
    optimizeAccumOptions fs =
        L.nub $ mapMaybe asOption $ selectClusters $ filter isSupportByAccum fs
        where
            -- Keep a cluster only when optimizing it changes the function set.
            asOption old
                | S.fromList old == S.fromList new = Nothing
                | otherwise = Just OptimizeAccum{refOld = old, refNew = new}
                where
                    new = optimizeCluster old

    optimizeAccumDecision fs (OptimizeAccum old new) = new <> (fs L.\\ old)
86
-- | All pairs of different functions from @fs@ where one function reads a
-- variable written by the other. Each pair is reported once; pairs holding the
-- same functions in a different order are treated as duplicates.
selectClusters fs = L.nubBy sameFunctions $ do
    f <- fs
    f' <- filter (/= f) fs
    [[f, f'] | linked f f']
    where
        sameFunctions a b = S.fromList a == S.fromList b
        -- @writesTo a b@: some output of @a@ is an input of @b@.
        writesTo a b = not $ S.disjoint (outputs a) (inputs b)
        linked a b = writesTo b a || writesTo a b
98 isIntersection a b = not $ S.disjoint a b
99
-- | Whether the function is an 'Add', 'Sub', 'Neg' or 'Acc', i.e. one of the
-- functions this refactoring can turn into an accumulator.
isSupportByAccum f = isAdd || isSub || isNeg || isAcc
    where
        isAdd = case castF f of
            Just Add{} -> True
            _ -> False
        isSub = case castF f of
            Just Sub{} -> True
            _ -> False
        isNeg = case castF f of
            Just Neg{} -> True
            _ -> False
        isAcc = case castF f of
            Just Acc{} -> True
            _ -> False
106
{- | Index functions by their inputs: a map from every input variable to all
functions of @fs@ that read it, listed in the order they appear in @fs@.

'M.unionsWith' is used instead of the left-biased 'M.unions' so that a
variable read by several functions keeps all of them, not only the first.
-}
containerMapCreate fs =
    M.unionsWith (++) [M.fromSet (const [f]) (inputs f) | f <- fs]
120
{- | Optimize a cluster of functions. For every function (the producer), each
function of the cluster that reads one of its outputs (the consumer) is passed
to 'refactorFunction' together with it, and all results are concatenated.
A function whose outputs are not read inside the cluster yields no entries as
a producer.
-}
optimizeCluster fs =
    [ result
    | producer <- fs
    , o <- S.toList $ outputs producer
    , consumer <- M.findWithDefault [] o readers
    , result <- refactorFunction producer consumer
    ]
    where
        readers = containerMapCreate fs
134
{- | Try to merge a producer function into a consumer that reads its output.

@refactorFunction producer consumer@ first converts both functions to 'Acc'
(see 'fromAddSub'). A merge is performed when both conversions succeed, the
producer has exactly one output variable, and the consumer pushes that
variable. In that case, every push of that variable in the consumer is
replaced by the producer's pushes, with signs flipped when the variable was
subtracted, and the merged 'Acc' is returned as a singleton list.

Otherwise @[consumer, producer]@ is returned unchanged.
-}
refactorFunction f' f
    | Just producer <- fromAddSub f'
    , Just consumer <- fromAddSub f
    , Just merged <- mergeAccs producer consumer =
        [merged]
    | otherwise = [f, f']
    where
        -- Merge two 'Acc's directly rather than re-entering refactorFunction:
        -- the recursive call looped forever when the consumer did not read
        -- the producer's (single) output.
        mergeAccs producer consumer
            | Just (Acc producerOps) <- castF producer
            , Just (Acc consumerOps) <- castF consumer
            , S.size produced == 1
            , any readsProduced consumerOps =
                Just $ packF $ Acc $ concatMap (inline producerOps) consumerOps
            | otherwise = Nothing
            where
                produced = outputs producer

                readsProduced (Push _ (I v)) = v `S.member` produced
                readsProduced _ = False

                -- Replace a push of the produced variable by the producer's
                -- actions; every other action is kept as is.
                inline ops (Push sign (I v))
                    | v `S.member` produced = mapMaybe (substitute sign v) ops
                inline _ action = [action]

                -- Subtracting the produced value negates each pushed term.
                -- The producer's pulls lose the inlined variable and are
                -- dropped once they pull nothing.
                substitute _ v pull@Pull{} = deleteFromPull v pull
                substitute Plus _ push = Just push
                substitute Minus _ (Push Plus i) = Just $ Push Minus i
                substitute Minus _ (Push Minus i) = Just $ Push Plus i
167
-- | Remove variable @v@ from the output set of a 'Pull'. Returns 'Nothing'
-- when no outputs would remain; applying it to a 'Push' is an error.
deleteFromPull v = \case
    Pull (O vs) ->
        let remaining = S.delete v vs
         in if S.null remaining then Nothing else Just $ Pull $ O remaining
    Push _ _ -> error "delete only Pull"
174
{- | Express an 'Add', 'Sub' or 'Neg' as an equivalent accumulator function
(built with 'acc'): the operands become pushes, and there is one 'Pull' per
output variable. An 'Acc' is returned unchanged; any other function gives
'Nothing'.
-}
fromAddSub f
    | Just (Add a b (O out)) <- castF f =
        Just $ acc $ [Push Plus a, Push Plus b] <> pullEach out
    | Just (Sub a b (O out)) <- castF f =
        Just $ acc $ [Push Plus a, Push Minus b] <> pullEach out
    | Just (Neg a (O out)) <- castF f =
        Just $ acc $ Push Minus a : pullEach out
    | Just Acc{} <- castF f = Just f
    | otherwise = Nothing
    where
        -- One pull per output variable.
        pullEach = map (Pull . O . S.singleton) . S.toList