never executed always true always false
    1 {-# LANGUAGE GADTs #-}
    2 {-# LANGUAGE MultiParamTypeClasses #-}
    3 {-# LANGUAGE NoMonomorphismRestriction #-}
    4 
    5 {-# OPTIONS -fno-warn-orphans #-}
    6 
    7 {- |
    8 Module      : NITTA.Synthesis.Steps.Allocation
    9 Description : Implementation of SynthesisDecisionCls for allocating PUs
   10 Copyright   : (c) Aleksandr Penskoi, Vitaliy Zakusilo, 2022
   11 License     : BSD3
   12 Maintainer  : aleksandr.penskoi@gmail.com
   13 Stability   : experimental
   14 -}
   15 module NITTA.Synthesis.Steps.Allocation (
   16     AllocationMetrics (..),
   17 ) where
   18 
   19 import Data.Aeson (ToJSON)
   20 import Data.Map qualified as M
   21 import GHC.Generics (Generic)
   22 import NITTA.Intermediate.Analysis (ProcessWave (ProcessWave, pwFs))
   23 import NITTA.Model.Networks.Bus (BusNetwork (bnPUPrototypes, bnPus, bnRemains))
   24 import NITTA.Model.Networks.Types (PU (PU, unit), PUPrototype (..))
   25 import NITTA.Model.Problems.Allocation (
   26     Allocation (Allocation, processUnitTag),
   27     AllocationProblem (allocationDecision),
   28  )
   29 import NITTA.Model.ProcessorUnits.Types (
   30     ParallelismType (..),
   31     ProcessorUnit (parallelismType),
   32     UnitTag,
   33     allowToProcess,
   34  )
   35 import NITTA.Model.TargetSystem (TargetSystem (TargetSystem, mUnit))
   36 import NITTA.Synthesis.Types (
   37     SynthesisDecisionCls (..),
   38     SynthesisState (SynthesisState, numberOfProcessWaves, processWaves, sTarget),
   39  )
   40 
   41 data AllocationMetrics = AllocationMetrics
   42     { mParallelism :: ParallelismType
   43     -- ^ PU prototype parallelism type
   44     , mRelatedRemains :: Float
   45     -- ^ The number of remaining functions that can be bound to pu
   46     , mMinPusForRemains :: Float
   47     -- ^ The minimum number of PUs for each of the remaining functions that can process it
   48     , mMaxParallels :: Float
   49     -- ^ The maximum number of functions that could be processed in parallel
   50     , mAvgParallels :: Float
   51     -- ^ The number of functions that can be processed in parallel on average
   52     }
   53     deriving (Generic)
   54 
   55 instance ToJSON AllocationMetrics
   56 
   57 instance
   58     UnitTag tag =>
   59     SynthesisDecisionCls
   60         (SynthesisState (TargetSystem (BusNetwork tag v x t) tag v x t) tag v x t)
   61         (TargetSystem (BusNetwork tag v x t) tag v x t)
   62         (Allocation tag)
   63         (Allocation tag)
   64         AllocationMetrics
   65     where
   66     decisions SynthesisState{sTarget} o = [(o, allocationDecision sTarget o)]
   67 
   68     parameters SynthesisState{sTarget = TargetSystem{mUnit}, processWaves, numberOfProcessWaves} Allocation{processUnitTag} _ =
   69         let pus = M.elems $ bnPus mUnit
   70             tmp = bnPUPrototypes mUnit M.! processUnitTag
   71             mParallelism PUPrototype{pProto} = parallelismType pProto
   72             canProcessTmp PUPrototype{pProto} f = allowToProcess f pProto
   73             canProcessPU PU{unit} f = allowToProcess f unit
   74             relatedRemains = filter (canProcessTmp tmp) $ bnRemains mUnit
   75             fCountByWaves = map (\ProcessWave{pwFs} -> length $ filter (canProcessTmp tmp) pwFs) processWaves
   76          in AllocationMetrics
   77                 { mParallelism = mParallelism tmp
   78                 , mRelatedRemains = fromIntegral $ length relatedRemains
   79                 , mMinPusForRemains = fromIntegral $ foldr (min . (\f -> length $ filter (`canProcessPU` f) pus)) (maxBound :: Int) relatedRemains
   80                 , mMaxParallels = fromIntegral $ maximum fCountByWaves
   81                 , mAvgParallels = (fromIntegral (sum fCountByWaves) :: Float) / (fromIntegral numberOfProcessWaves :: Float)
   82                 }
   83 
   84     estimate _ctx _o _d AllocationMetrics{mParallelism, mMinPusForRemains, mAvgParallels}
   85         | mMinPusForRemains == 0 = 5000
   86         | mParallelism == Full = -1
   87         | mParallelism == Pipeline && (mAvgParallels / mMinPusForRemains >= 3) = 4900
   88         | mAvgParallels / mMinPusForRemains >= 2 = 4900
   89         | otherwise = -1