never executed always true always false
1 {-# LANGUAGE GADTs #-}
2 {-# LANGUAGE MultiWayIf #-}
3 {-# LANGUAGE OverloadedStrings #-}
4 {-# LANGUAGE RankNTypes #-}
5 {-# LANGUAGE NoMonomorphismRestriction #-}
6
7 {- |
8 Module : NITTA.Synthesis.Explore
9 Description : Explore synthesis tree
10 Copyright : (c) Aleksandr Penskoi, 2021
11 License : BSD3
12 Maintainer : aleksandr.penskoi@gmail.com
13 Stability : experimental
14 -}
15 module NITTA.Synthesis.Explore (
16 synthesisTreeRootIO,
17 getTreeIO,
18 getTreePathIO,
19 subForestIO,
20 positiveSubForestIO,
21 ) where
22
23 import Control.Concurrent.STM
24 import Control.Exception
25 import Control.Monad (foldM, forM, unless, when)
26 import Data.Default
27 import Data.Map.Strict qualified as M
28 import Data.Maybe
29 import Data.Set qualified as S
30 import Data.Text qualified as T
31 import Debug.Trace (trace)
32 import NITTA.Intermediate.Analysis (buildProcessWaves, estimateVarWaves)
33 import NITTA.Intermediate.Types
34 import NITTA.Model.Networks.Bus
35 import NITTA.Model.Problems.Allocation
36 import NITTA.Model.Problems.Bind
37 import NITTA.Model.Problems.Dataflow
38 import NITTA.Model.Problems.Refactor
39 import NITTA.Model.TargetSystem
40 import NITTA.Synthesis.MlBackend.Client
41 import NITTA.Synthesis.MlBackend.ServerInstance
42 import NITTA.Synthesis.Types
43 import NITTA.UIBackend.Types
44 import NITTA.UIBackend.ViewHelper
45 import NITTA.Utils
46 import Network.HTTP.Simple
47 import System.Log.Logger
48
49 -- | Make synthesis tree
50 synthesisTreeRootIO = atomically . rootSynthesisTreeSTM
51
52 rootSynthesisTreeSTM model = do
53 sSubForestVar <- newEmptyTMVar
54 let sState = nodeCtx Nothing model
55 return
56 Tree
57 { sID = def
58 , sState
59 , sDecision = Root
60 , sSubForestVar
61 , isLeaf = isLeaf' sState
62 , isComplete = isComplete' sState
63 }
64
65 -- | Get specific by @nId@ node from a synthesis tree.
66 getTreeIO _ctx tree (Sid []) = return tree
67 getTreeIO ctx tree (Sid (i : is)) = do
68 subForest <- subForestIO ctx tree
69 unless (i < length subForest) $ error "getTreeIO - wrong Sid"
70 getTreeIO ctx (subForest !! i) (Sid is)
71
72 -- | Get list of all nodes from root to selected.
73 getTreePathIO _ctx _tree (Sid []) = return []
74 getTreePathIO ctx tree (Sid (i : is)) = do
75 h <- getTreeIO ctx tree $ Sid [i]
76 t <- getTreePathIO ctx h $ Sid is
77 return $ h : t
78
79 {- | Get all available edges for the node. Edges calculated only for the first
80 call.
81 -}
82 subForestIO
83 BackendCtx{nodeScores, mlBackendGetter}
84 tree@Tree{sSubForestVar} = do
85 (firstTime, subForest) <-
86 atomically $
87 tryReadTMVar sSubForestVar >>= \case
88 Just subForest -> return (False, subForest)
89 Nothing -> do
90 subForest <- exploreSubForestVar tree
91 putTMVar sSubForestVar subForest
92 return (True, subForest)
93
94 when firstTime $ traceProcessedNode tree
95
96 -- FIXME: ML scores are evaluated here every time subForestIO is called. how to cache it like the default score? IO in STM isn't possible.
97 -- also it looks inelegant, is there a way to refactor it?
98 let modelNames = mapMaybe (T.stripPrefix mlScoreKeyPrefix) nodeScores
99 if
100 | null subForest -> return subForest
101 | null nodeScores -> return subForest
102 | null modelNames -> return subForest
103 | otherwise -> do
104 MlBackendServer{baseUrl} <- mlBackendGetter
105 case baseUrl of
106 Nothing -> return subForest
107 Just mlBackendBaseUrl -> do
108 -- (addMlScoreToSubforestSkipErrorsIO subForestAccum modelName) gets called for each modelName
109 foldM (addMlScoreToSubforestSkipErrorsIO mlBackendBaseUrl) subForest modelNames
110 where
111 traceProcessedNode Tree{sID, sDecision} =
112 debugM "NITTA.Synthesis" $
113 "explore: "
114 <> show sID
115 <> " score: "
116 <> ( case sDecision of
117 SynthesisDecision{scores} -> show scores
118 _ -> "-"
119 )
120 <> " decision: "
121 <> ( case sDecision of
122 SynthesisDecision{decision} -> show decision
123 _ -> "-"
124 )
125
126 addMlScoreToSubforestSkipErrorsIO mlBackendBaseUrl subForest modelName = do
127 addMlScoreToSubforestIO mlBackendBaseUrl subForest modelName
128 `catch` \e -> do
129 errorM "NITTA.Synthesis" $
130 "ML backend error: "
131 <> ( case e of
132 JSONConversionException _ resp _ -> show resp
133 _ -> show e
134 )
135 return subForest
136
137 addMlScoreToSubforestIO mlBackendBaseUrl subForest modelName = do
138 let input = ScoringInput{scoringTarget = ScoringTargetAll, nodes = [view node | node <- subForest]}
139 allInputsScores <- predictScoresIO modelName mlBackendBaseUrl [input]
140 -- +20 shifts "useless node" threshold, since model outputs negative values much more often
141 -- FIXME: make models' output consist of mostly >0 values and treat 0 as a "useless node" threshold? training data changes required
142 let mlScores = map (+ 20) $ head allInputsScores
143 scoreKey = mlScoreKeyPrefix <> modelName
144
145 return $
146 map
147 (addNewScoreToSubforest scoreKey)
148 (zip subForest mlScores)
149
150 addNewScoreToSubforest scoreKey (node@Tree{sDecision = sDes@SynthesisDecision{scores = origScores}}, newScore) =
151 node{sDecision = sDes{scores = M.insert scoreKey newScore origScores}}
152 addNewScoreToSubforest scoreKey (node@Tree{sDecision = Root}, _) =
153 trace ("adding new score to Root, shouldn't happen, scoreKey: " ++ fromText scoreKey) node
154
155 {- | For synthesis method is more usefull, because throw away all useless trees in
156 subForest (objective function value less than zero).
157 -}
158 positiveSubForestIO ctx tree = filter ((> 0) . defScore . sDecision) <$> subForestIO ctx tree
159
160 isLeaf'
161 SynthesisState
162 { sAllocationOptions = []
163 , sBindOptions = []
164 , sDataflowOptions = []
165 , sBreakLoopOptions = []
166 , sResolveDeadlockOptions = []
167 , sOptimizeAccumOptions = []
168 , sOptimizeLogicalUnitOptions = []
169 , sConstantFoldingOptions = []
170 } = True
171 isLeaf' _ = False
172
173 isComplete' = isSynthesisComplete . sTarget
174
175 -- * Internal
176
177 exploreSubForestVar parent@Tree{sID, sState} =
178 let edges =
179 concat
180 ( map (decisionAndContext parent) (sAllocationOptions sState)
181 ++ map (decisionAndContext parent) (sBindOptions sState)
182 ++ map (decisionAndContext parent) (sDataflowOptions sState)
183 ++ map (decisionAndContext parent) (sBreakLoopOptions sState)
184 ++ map (decisionAndContext parent) (sResolveDeadlockOptions sState)
185 ++ map (decisionAndContext parent) (sOptimizeAccumOptions sState)
186 ++ map (decisionAndContext parent) (sOptimizeLogicalUnitOptions sState)
187 ++ map (decisionAndContext parent) (sConstantFoldingOptions sState)
188 )
189 in forM (zip [0 ..] edges) $ \(i, (desc, ctx')) -> do
190 sSubForestVar <- newEmptyTMVar
191 return
192 Tree
193 { sID = sID <> Sid [i]
194 , sState = ctx'
195 , sDecision = desc
196 , sSubForestVar
197 , isLeaf = isLeaf' ctx'
198 , isComplete = isComplete' ctx'
199 }
200
201 decisionAndContext parent@Tree{sState = ctx} o =
202 [ (SynthesisDecision o d p e, nodeCtx (Just parent) model)
203 | (d, model) <- decisions ctx o
204 , let p = parameters ctx o d
205 e = M.singleton "default" $ estimate ctx o d p
206 ]
207
208 nodeCtx parent nModel =
209 let sBindOptions = bindOptions nModel
210 sDataflowOptions = dataflowOptions nModel
211 fs = functions $ mDataFlowGraph nModel
212 processWaves = buildProcessWaves [] fs
213 in SynthesisState
214 { sTarget = nModel
215 , sParent = parent
216 , sAllocationOptions = allocationOptions nModel
217 , sBindOptions
218 , sDataflowOptions
219 , sResolveDeadlockOptions = resolveDeadlockOptions nModel
220 , sBreakLoopOptions = breakLoopOptions nModel
221 , sConstantFoldingOptions = constantFoldingOptions nModel
222 , sOptimizeAccumOptions = optimizeAccumOptions nModel
223 , sOptimizeLogicalUnitOptions = optimizeLogicalUnitOptions nModel
224 , bindingAlternative =
225 foldl
226 ( \st b -> case b of
227 (SingleBind uTag f) -> M.alter (return . maybe [uTag] (uTag :)) f st
228 _ -> st
229 )
230 M.empty
231 sBindOptions
232 , possibleDeadlockBinds =
233 S.fromList
234 [ f
235 | (SingleBind uTag f) <- sBindOptions
236 , Lock{lockBy} <- locks f
237 , lockBy `S.member` unionsMap variables (boundFunctions uTag $ mUnit nModel)
238 ]
239 , bindWaves = estimateVarWaves (S.elems (variables (mUnit nModel) S.\\ unionsMap variables sBindOptions)) fs
240 , processWaves
241 , numberOfProcessWaves = length processWaves
242 , numberOfDataflowOptions = length sDataflowOptions
243 , transferableVars =
244 S.unions
245 [ variables ep
246 | (DataflowSt _ targets) <- sDataflowOptions
247 , (_, ep) <- targets
248 ]
249 , unitWorkloadInFunction =
250 let
251 BusNetwork{bnBound, bnPus} = mUnit nModel
252 in
253 M.fromList
254 $ map
255 (\uTag -> (uTag, maybe 0 length $ bnBound M.!? uTag))
256 $ M.keys bnPus
257 }