never executed always true always false
    1 -- All extensions should be enabled explicitly due to doctest in this module.
    2 {-# LANGUAGE ConstraintKinds #-}
    3 {-# LANGUAGE DeriveGeneric #-}
    4 {-# LANGUAGE FlexibleContexts #-}
    5 {-# LANGUAGE FlexibleInstances #-}
    6 {-# LANGUAGE FunctionalDependencies #-}
    7 {-# LANGUAGE GADTs #-}
    8 {-# LANGUAGE ImportQualifiedPost #-}
    9 {-# LANGUAGE LambdaCase #-}
   10 {-# LANGUAGE NamedFieldPuns #-}
   11 {-# LANGUAGE TypeFamilies #-}
   12 {-# LANGUAGE UndecidableInstances #-}
   13 
   14 {- |
   15 Module      : NITTA.Model.Problems.Refactor.OptimizeAccum
   16 Description : Optimize an algorithm for Accum processor unit
   17 Copyright   : (c) Daniil Prohorov, 2021
   18 License     : BSD3
   19 Maintainer  : aleksandr.penskoi@gmail.com
   20 Stability   : experimental
   21 -}
   22 module NITTA.Model.Problems.Refactor.OptimizeAccum (
   23     OptimizeAccum (..),
   24     OptimizeAccumProblem (..),
   25 )
   26 where
   27 
   28 import Data.List qualified as L
   29 import Data.Map qualified as M
   30 import Data.Maybe
   31 import Data.Set qualified as S
   32 import GHC.Generics
   33 import NITTA.Intermediate.Functions
   34 import NITTA.Intermediate.Types
   35 
   36 {- | OptimizeAccum example:
   37 
   38 > OptimizeAccum [+a +b => tmp1; +tmp1 +c => res] [+a +b +c => d]
   39 
   40 before:
   41 
   42 > [+a +b => tmp1; +tmp1 +c => res]
   43 
   44 after:
   45 
   46 > [+a +b +c => res]
   47 
   48 == Doctest optimize example
   49 
   50 >>> let a = constant 1 ["a"]
   51 >>> let b = constant 2 ["b"]
   52 >>> let c = constant 3 ["c"]
   53 >>> let tmp1 = add "a" "b" ["tmp1"]
   54 >>> let res = add "tmp1" "c" ["res"]
   55 >>> let loopRes = loop 1 "e" ["res"]
   56 >>> let fs = [a, b, c, tmp1, res, loopRes] :: [F String Int]
   57 >>> optimizeAccumDecision fs $ head $ optimizeAccumOptions fs
   58 [Acc(+a +b +c = res),const(1) = a,const(2) = b,const(3) = c,loop(1, e) = res]
   59 -}
   60 data OptimizeAccum v x = OptimizeAccum
   61     { refOld :: [F v x]
   62     , refNew :: [F v x]
   63     }
   64     deriving (Generic, Show, Eq)
   65 
   66 class OptimizeAccumProblem u v x | u -> v x where
   67     -- | Function takes algorithm in 'DataFlowGraph' and return list of 'Refactor' that can be done
   68     optimizeAccumOptions :: u -> [OptimizeAccum v x]
   69     optimizeAccumOptions _ = []
   70 
   71     -- | Function takes 'OptimizeAccum' and modify 'DataFlowGraph'
   72     optimizeAccumDecision :: u -> OptimizeAccum v x -> u
   73     optimizeAccumDecision _ _ = error "not implemented"
   74 
   75 instance (Var v, Val x) => OptimizeAccumProblem [F v x] v x where
   76     optimizeAccumOptions fs = res
   77         where
   78             res =
   79                 L.nub
   80                     [ OptimizeAccum{refOld, refNew}
   81                     | refOld <- selectClusters $ filter isSupportByAccum fs
   82                     , let refNew = optimizeCluster refOld
   83                     , S.fromList refOld /= S.fromList refNew
   84                     ]
   85 
   86     optimizeAccumDecision fs OptimizeAccum{refOld, refNew} = refNew <> (fs L.\\ refOld)
   87 
   88 selectClusters fs =
   89     L.nubBy
   90         (\a b -> S.fromList a == S.fromList b)
   91         [ [f, f']
   92         | f <- fs
   93         , f' <- fs
   94         , f' /= f
   95         , inputOutputIntersect f f'
   96         ]
   97     where
   98         inputOutputIntersect f1 f2 = isIntersection (inputs f1) (outputs f2) || isIntersection (inputs f2) (outputs f1)
   99         isIntersection a b = not $ S.disjoint a b
  100 
  101 isSupportByAccum f
  102     | Just Add{} <- castF f = True
  103     | Just Sub{} <- castF f = True
  104     | Just Neg{} <- castF f = True
  105     | Just Acc{} <- castF f = True
  106     | otherwise = False
  107 
  108 -- | Create Map String (HistoryTree (F v x)), where Key is input label and Value is FU that contain this input label
  109 containerMapCreate fs =
  110     M.unions $
  111         map
  112             ( \f ->
  113                 foldl
  114                     ( \dataMap k ->
  115                         M.insertWith (++) k [f] dataMap
  116                     )
  117                     M.empty
  118                     (S.toList $ inputs f)
  119             )
  120             fs
  121 
  122 -- | Takes container and refactor it, if it can be
  123 optimizeCluster fs = concatMap refactored fs
  124     where
  125         containerMap = containerMapCreate fs
  126 
  127         refactored f =
  128             concatMap
  129                 ( \o ->
  130                     case M.findWithDefault [] o containerMap of
  131                         [] -> []
  132                         matchedFUs -> concatMap (refactorFunction f) matchedFUs
  133                 )
  134                 (S.toList $ outputs f)
  135 
  136 refactorFunction f' f
  137     | Just (Acc lst') <- castF f'
  138     , Just (Acc lst) <- castF f
  139     , let singleOutBool = (1 ==) $ length $ outputs f'
  140           isOutInpIntersect =
  141             any
  142                 ( \case
  143                     Push _ (I v) -> elem v $ outputs f'
  144                     _ -> False
  145                 )
  146                 lst
  147           makeRefactor = singleOutBool && isOutInpIntersect
  148        in makeRefactor =
  149         let subs _ (Push Minus _) (Push Plus v) = Just $ Push Minus v
  150             subs _ (Push Minus _) (Push Minus v) = Just $ Push Plus v
  151             subs _ (Push Plus _) push@(Push _ _) = Just push
  152             subs v _ pull@(Pull _) = deleteFromPull v pull
  153             subs _ _ _ = error "Pull can not be here"
  154 
  155             refactorAcc _ _ (Pull o) = [Pull o]
  156             refactorAcc accList accNew accOld@(Push s i@(I v))
  157                 | elem v $ outputs accNew = mapMaybe (subs v accOld) accList
  158                 | s == Minus = [Push Minus i]
  159                 | s == Plus = [Push Plus i]
  160             refactorAcc _ _ (Push _ (I _)) = undefined
  161          in [packF $ Acc $ concatMap (refactorAcc lst' f') lst]
  162     | Just f1 <- fromAddSub f'
  163     , Just f2 <- fromAddSub f
  164     , (1 ==) $ length $ outputs f' = case refactorFunction f1 f2 of
  165         [fNew] -> [fNew]
  166         _ -> [f, f']
  167     | otherwise = [f, f']
  168 
  169 deleteFromPull v (Pull (O s))
  170     | S.null deleted = Nothing
  171     | otherwise = Just $ Pull $ O deleted
  172     where
  173         deleted = S.delete v s
  174 deleteFromPull _ (Push _ _) = error "delete only Pull"
  175 
  176 fromAddSub f
  177     | Just (Add in1 in2 (O out)) <- castF f =
  178         Just $
  179             acc $
  180                 [Push Plus in1, Push Plus in2] ++ [Pull $ O $ S.fromList [o] | o <- S.toList out]
  181     | Just (Sub in1 in2 (O out)) <- castF f =
  182         Just $
  183             acc $
  184                 [Push Plus in1, Push Minus in2] ++ [Pull $ O $ S.fromList [o] | o <- S.toList out]
  185     | Just (Neg in1 (O out)) <- castF f =
  186         Just $
  187             acc $
  188                 Push Minus in1 : [Pull $ O $ S.singleton o | o <- S.toList out]
  189     | Just Acc{} <- castF f = Just f
  190     | otherwise = Nothing