never executed always true always false
1 {-# LANGUAGE GADTs #-}
2 {-# LANGUAGE MultiWayIf #-}
3 {-# LANGUAGE OverloadedStrings #-}
4 {-# LANGUAGE RankNTypes #-}
5 {-# LANGUAGE NoMonomorphismRestriction #-}
6
7 {- |
8 Module : NITTA.Synthesis.Explore
9 Description : Explore synthesis tree
10 Copyright : (c) Aleksandr Penskoi, 2021
11 License : BSD3
12 Maintainer : aleksandr.penskoi@gmail.com
13 Stability : experimental
14 -}
15 module NITTA.Synthesis.Explore (
16 synthesisTreeRootIO,
17 getTreeIO,
18 getTreePathIO,
19 subForestIO,
20 positiveSubForestIO,
21 ) where
22
23 import Control.Concurrent.STM
24 import Control.Exception
25 import Control.Monad (foldM, forM, unless, when)
26 import Data.Default
27 import Data.Map.Strict qualified as M
28 import Data.Maybe
29 import Data.Set qualified as S
30 import Data.Text qualified as T
31 import Debug.Trace (trace)
32 import NITTA.Intermediate.Analysis (buildProcessWaves, estimateVarWaves)
33 import NITTA.Intermediate.Types
34 import NITTA.Model.Networks.Bus
35 import NITTA.Model.Problems.Allocation
36 import NITTA.Model.Problems.Bind
37 import NITTA.Model.Problems.Dataflow
38 import NITTA.Model.Problems.Refactor
39 import NITTA.Model.TargetSystem
40 import NITTA.Synthesis.MlBackend.Client
41 import NITTA.Synthesis.MlBackend.ServerInstance
42 import NITTA.Synthesis.Types
43 import NITTA.UIBackend.Types
44 import NITTA.UIBackend.ViewHelper
45 import NITTA.Utils
46 import Network.HTTP.Simple
47 import System.Log.Logger
48
-- | Make the root node of a synthesis tree for the target @model@. The node is
-- built by 'rootSynthesisTreeSTM' inside a single 'atomically' transaction.
synthesisTreeRootIO model = atomically $ rootSynthesisTreeSTM model
51
-- | Build (in STM) the root node for @model@: default @Sid@, 'Root' decision,
-- a synthesis state without a parent, and a fresh empty subforest variable.
rootSynthesisTreeSTM model =
    let rootState = nodeCtx Nothing model
     in fmap
            ( \subForestVar ->
                Tree
                    { sID = def
                    , sState = rootState
                    , sDecision = Root
                    , sSubForestVar = subForestVar
                    , isLeaf = isLeaf' rootState
                    , isComplete = isComplete' rootState
                    }
            )
            newEmptyTMVar
64
{- | Get the node addressed by a @Sid@ relative to @tree@. Each index of the
@Sid@ selects a child from the subforest (as returned by 'subForestIO') of the
current node.

Throws 'ErrorCall' @"getTreeIO - wrong Sid"@ when an index is negative or not
less than the number of children.
-}
getTreeIO _ctx tree (Sid []) = return tree
getTreeIO ctx tree (Sid (i : is)) = do
    subForest <- subForestIO ctx tree
    -- 'drop' + pattern match instead of 'length' + partial '!!'; a negative
    -- index must be rejected explicitly because 'drop' would accept it.
    case drop i subForest of
        child : _ | i >= 0 -> getTreeIO ctx child (Sid is)
        _ -> throwIO $ ErrorCall "getTreeIO - wrong Sid"
71
-- | Get the list of nodes on the path addressed by a @Sid@, starting with the
-- child of @tree@ and ending with the selected node (@tree@ itself is not
-- included; an empty @Sid@ gives an empty list).
getTreePathIO _ctx _tree (Sid []) = return []
getTreePathIO ctx tree (Sid (i : is)) =
    getTreeIO ctx tree (Sid [i]) >>= \child ->
        (child :) <$> getTreePathIO ctx child (Sid is)
78
{- | Get all available edges (child nodes) for the node. Edges are calculated
only on the first call and then cached in @sSubForestVar@.

When @nodeScores@ contains keys with the 'mlScoreKeyPrefix' prefix and the ML
backend has a base URL, every child additionally gets one score per requested
model. ML scores are not cached: they are requested on every call. Failures of
the ML backend (JSON errors or a response with an unexpected number of scores)
are logged, and the subforest is returned without that model's scores.
-}
subForestIO
    BackendCtx{nodeScores, mlBackendGetter}
    tree@Tree{sSubForestVar} = do
        (firstTime, subForest) <-
            atomically $
                tryReadTMVar sSubForestVar >>= \case
                    Just subForest -> return (False, subForest)
                    Nothing -> do
                        subForest <- exploreSubForestVar tree
                        putTMVar sSubForestVar subForest
                        return (True, subForest)

        when firstTime $ traceProcessedNode tree

        -- FIXME: ML scores are evaluated here every time subForestIO is called. how to cache it like the default score? IO in STM isn't possible.
        -- also it looks inelegant, is there a way to refactor it?
        let modelNames = mapMaybe (T.stripPrefix mlScoreKeyPrefix) nodeScores
        if
            | null subForest -> return subForest
            | null nodeScores -> return subForest
            | null modelNames -> return subForest
            | otherwise -> do
                MlBackendServer{baseUrl} <- mlBackendGetter
                case baseUrl of
                    Nothing -> return subForest
                    Just mlBackendBaseUrl ->
                        -- (addMlScoreToSubforestSkipErrorsIO subForestAccum modelName) gets called for each modelName
                        foldM (addMlScoreToSubforestSkipErrorsIO mlBackendBaseUrl) subForest modelNames
    where
        -- Log the node whose subforest was just computed.
        traceProcessedNode Tree{sID, sDecision} =
            debugM "NITTA.Synthesis" $
                "explore: "
                    <> show sID
                    <> " score: "
                    <> ( case sDecision of
                            SynthesisDecision{scores} -> show scores
                            _ -> "-"
                       )
                    <> " decision: "
                    <> ( case sDecision of
                            SynthesisDecision{decision} -> show decision
                            _ -> "-"
                       )

        -- Add one model's scores; on a JSON error log it and keep the subforest as is.
        addMlScoreToSubforestSkipErrorsIO mlBackendBaseUrl subForest modelName =
            addMlScoreToSubforestIO mlBackendBaseUrl subForest modelName
                `catch` \e -> do
                    errorM "NITTA.Synthesis" $
                        "ML backend error: "
                            <> ( case e of
                                    JSONConversionException _ resp _ -> show resp
                                    _ -> show e
                               )
                    return subForest

        -- Request scores of @modelName@ for all children and attach them.
        addMlScoreToSubforestIO mlBackendBaseUrl subForest modelName = do
            let input = ScoringInput{scoringTarget = ScoringTargetAll, nodes = [view node | node <- subForest]}
            allInputsScores <- predictScoresIO modelName mlBackendBaseUrl [input]
            -- The response is checked here (inside the caught region) instead of a
            -- lazy 'head': an empty response or a score count different from the
            -- number of children would otherwise crash later or make 'zip' drop
            -- children (shifting child indexes).
            case allInputsScores of
                scores : _
                    | length scores == length subForest -> do
                        -- +20 shifts "useless node" threshold, since model outputs negative values much more often
                        -- FIXME: make models' output consist of mostly >0 values and treat 0 as a "useless node" threshold? training data changes required
                        let mlScores = map (+ 20) scores
                            scoreKey = mlScoreKeyPrefix <> modelName
                        return $
                            map
                                (addNewScoreToSubforest scoreKey)
                                (zip subForest mlScores)
                _ -> do
                    errorM "NITTA.Synthesis" $
                        "ML backend error: model "
                            <> T.unpack modelName
                            <> " returned an unexpected number of scores (expected "
                            <> show (length subForest)
                            <> "), its scores are skipped"
                    return subForest

        addNewScoreToSubforest scoreKey (node@Tree{sDecision = sDes@SynthesisDecision{scores = origScores}}, newScore) =
            node{sDecision = sDes{scores = M.insert scoreKey newScore origScores}}
        addNewScoreToSubforest scoreKey (node@Tree{sDecision = Root}, _) =
            trace ("adding new score to Root, shouldn't happen, scoreKey: " ++ fromText scoreKey) node
154
{- | Like 'subForestIO', but keeps only children whose decision has a strictly
positive @defScore@. This is more useful for synthesis methods, because it
throws away all useless subtrees (objective function value not greater than
zero).
-}
positiveSubForestIO ctx tree = do
    forest <- subForestIO ctx tree
    return [child | child <- forest, defScore (sDecision child) > 0]
159
-- | A synthesis state is a leaf when it offers no options at all: no
-- allocation, bind, dataflow, break-loop, resolve-deadlock, optimize-accum or
-- constant-folding options.
isLeaf' st =
    and
        [ null (sAllocationOptions st)
        , null (sBindOptions st)
        , null (sDataflowOptions st)
        , null (sBreakLoopOptions st)
        , null (sResolveDeadlockOptions st)
        , null (sOptimizeAccumOptions st)
        , null (sConstantFoldingOptions st)
        ]
171
-- | A synthesis state is complete when its target model passes 'isSynthesisComplete'.
isComplete' st = isSynthesisComplete (sTarget st)
173
174 -- * Internal
175
-- | Build (in STM) the child nodes of @parent@: one per decision available in
-- the parent's synthesis state, ordered by option kind (allocation, bind,
-- dataflow, break loop, resolve deadlock, optimize accum, constant folding).
-- Child number @i@ gets the parent's @Sid@ extended with @i@ and a fresh empty
-- subforest variable.
exploreSubForestVar parent@Tree{sID, sState} =
    let edges =
            concatMap (decisionAndContext parent) (sAllocationOptions sState)
                ++ concatMap (decisionAndContext parent) (sBindOptions sState)
                ++ concatMap (decisionAndContext parent) (sDataflowOptions sState)
                ++ concatMap (decisionAndContext parent) (sBreakLoopOptions sState)
                ++ concatMap (decisionAndContext parent) (sResolveDeadlockOptions sState)
                ++ concatMap (decisionAndContext parent) (sOptimizeAccumOptions sState)
                ++ concatMap (decisionAndContext parent) (sConstantFoldingOptions sState)
        mkChild i (desc, ctx') = do
            childVar <- newEmptyTMVar
            return
                Tree
                    { sID = sID <> Sid [i]
                    , sState = ctx'
                    , sDecision = desc
                    , sSubForestVar = childVar
                    , isLeaf = isLeaf' ctx'
                    , isComplete = isComplete' ctx'
                    }
     in traverse (uncurry mkChild) (zip [0 ..] edges)
198
-- | All decisions for option @o@ in the state of @parent@. Each decision is
-- paired with the synthesis state of the model it produces. A decision carries
-- its parameters and a score map holding only the @"default"@ estimate.
decisionAndContext parent@Tree{sState = ctx} o =
    map toEdge (decisions ctx o)
    where
        toEdge (d, model) =
            let p = parameters ctx o d
                e = M.singleton "default" $ estimate ctx o d p
             in (SynthesisDecision o d p e, nodeCtx (Just parent) model)
205
{- | Build the synthesis state for a node whose target model is @nModel@;
@parent@ is the parent tree node ('Nothing' for the root).

Besides all option lists, the state caches values derived from the model:

* @bindingAlternative@ — for each function from 'SingleBind' options, the unit
  tags it can be bound to (options appearing later in the list come first);
* @possibleDeadlockBinds@ — functions from 'SingleBind' options that have a
  lock whose @lockBy@ variable is among the variables of
  @boundFunctions uTag@ of the network;
* @unitWorkloadInFunction@ — for each key of @bnPus@, the length of its
  @bnBound@ entry (0 when there is none).
-}
nodeCtx parent nModel =
    let sBindOptions = bindOptions nModel
        sDataflowOptions = dataflowOptions nModel
        fs = functions $ mDataFlowGraph nModel
        processWaves = buildProcessWaves [] fs
     in SynthesisState
            { sTarget = nModel
            , sParent = parent
            , sAllocationOptions = allocationOptions nModel
            , sBindOptions
            , sDataflowOptions
            , sResolveDeadlockOptions = resolveDeadlockOptions nModel
            , sBreakLoopOptions = breakLoopOptions nModel
            , sConstantFoldingOptions = constantFoldingOptions nModel
            , sOptimizeAccumOptions = optimizeAccumOptions nModel
            , bindingAlternative =
                -- Strict 'M.fromListWith' instead of a lazy 'foldl' of 'M.alter':
                -- @(++)@ prepends the newer singleton, so the order is unchanged.
                M.fromListWith (++) [(f, [uTag]) | (SingleBind uTag f) <- sBindOptions]
            , possibleDeadlockBinds =
                S.fromList
                    [ f
                    | (SingleBind uTag f) <- sBindOptions
                    , Lock{lockBy} <- locks f
                    , lockBy `S.member` unionsMap variables (boundFunctions uTag $ mUnit nModel)
                    ]
            , bindWaves = estimateVarWaves (S.elems (variables (mUnit nModel) S.\\ unionsMap variables sBindOptions)) fs
            , processWaves
            , numberOfProcessWaves = length processWaves
            , numberOfDataflowOptions = length sDataflowOptions
            , transferableVars =
                S.unions
                    [ variables ep
                    | (DataflowSt _ targets) <- sDataflowOptions
                    , (_, ep) <- targets
                    ]
            , unitWorkloadInFunction =
                let BusNetwork{bnBound, bnPus} = mUnit nModel
                 in -- same keys as 'bnPus', built without rebuilding the map from its key list
                    M.mapWithKey (\uTag _ -> maybe 0 length $ bnBound M.!? uTag) bnPus
            }