never executed always true always false
    1 -- All extensions should be enabled explicitly due to doctest in this module.
    2 {-# LANGUAGE ConstraintKinds #-}
    3 {-# LANGUAGE DeriveGeneric #-}
    4 {-# LANGUAGE FlexibleContexts #-}
    5 {-# LANGUAGE FlexibleInstances #-}
    6 {-# LANGUAGE FunctionalDependencies #-}
    7 {-# LANGUAGE GADTs #-}
    8 {-# LANGUAGE ImportQualifiedPost #-}
    9 {-# LANGUAGE LambdaCase #-}
   10 {-# LANGUAGE NamedFieldPuns #-}
   11 {-# LANGUAGE TypeFamilies #-}
   12 {-# LANGUAGE UndecidableInstances #-}
   13 
   14 {- |
   15 Module      : NITTA.Model.Problems.Refactor.OptimizeAccum
   16 Description : Optimize an algorithm for Accum processor unit
   17 Copyright   : (c) Daniil Prohorov, 2021
   18 License     : BSD3
   19 Maintainer  : aleksandr.penskoi@gmail.com
   20 Stability   : experimental
   21 -}
   22 module NITTA.Model.Problems.Refactor.OptimizeAccum (
   23     OptimizeAccum (..),
   24     OptimizeAccumProblem (..),
   25 ) where
   26 
   27 import Data.List qualified as L
   28 import Data.Map qualified as M
   29 import Data.Maybe
   30 import Data.Set qualified as S
   31 import GHC.Generics
   32 import NITTA.Intermediate.Functions
   33 import NITTA.Intermediate.Types
   34 
   35 {- | OptimizeAccum example:
   36 
> OptimizeAccum [+a +b => tmp1; +tmp1 +c => res] [+a +b +c => res]
   38 
   39 before:
   40 
   41 > [+a +b => tmp1; +tmp1 +c => res]
   42 
   43 after:
   44 
   45 > [+a +b +c => res]
   46 
   47 == Doctest optimize example
   48 
   49 >>> let a = constant 1 ["a"]
   50 >>> let b = constant 2 ["b"]
   51 >>> let c = constant 3 ["c"]
   52 >>> let tmp1 = add "a" "b" ["tmp1"]
   53 >>> let res = add "tmp1" "c" ["res"]
   54 >>> let loopRes = loop 1 "e" ["res"]
   55 >>> let fs = [a, b, c, tmp1, res, loopRes] :: [F String Int]
   56 >>> optimizeAccumDecision fs $ head $ optimizeAccumOptions fs
   57 [Acc(+a +b +c = res),const(1) = a,const(2) = b,const(3) = c,loop(1, e) = res]
   58 -}
   59 data OptimizeAccum v x = OptimizeAccum
   60     { refOld :: [F v x]
   61     , refNew :: [F v x]
   62     }
   63     deriving (Generic, Show, Eq)
   64 
   65 class OptimizeAccumProblem u v x | u -> v x where
   66     -- | Function takes algorithm in 'DataFlowGraph' and return list of 'Refactor' that can be done
   67     optimizeAccumOptions :: u -> [OptimizeAccum v x]
   68     optimizeAccumOptions _ = []
   69 
   70     -- | Function takes 'OptimizeAccum' and modify 'DataFlowGraph'
   71     optimizeAccumDecision :: u -> OptimizeAccum v x -> u
   72     optimizeAccumDecision _ _ = error "not implemented"
   73 
   74 instance (Var v, Val x) => OptimizeAccumProblem [F v x] v x where
   75     optimizeAccumOptions fs = res
   76         where
   77             res =
   78                 L.nub
   79                     [ OptimizeAccum{refOld, refNew}
   80                     | refOld <- selectClusters $ filter isSupportByAccum fs
   81                     , let refNew = optimizeCluster refOld
   82                     , S.fromList refOld /= S.fromList refNew
   83                     ]
   84 
   85     optimizeAccumDecision fs OptimizeAccum{refOld, refNew} = refNew <> (fs L.\\ refOld)
   86 
   87 selectClusters fs =
   88     L.nubBy
   89         (\a b -> S.fromList a == S.fromList b)
   90         [ [f, f']
   91         | f <- fs
   92         , f' <- fs
   93         , f' /= f
   94         , inputOutputIntersect f f'
   95         ]
   96     where
   97         inputOutputIntersect f1 f2 = isIntersection (inputs f1) (outputs f2) || isIntersection (inputs f2) (outputs f1)
   98         isIntersection a b = not $ S.disjoint a b
   99 
  100 isSupportByAccum f
  101     | Just Add{} <- castF f = True
  102     | Just Sub{} <- castF f = True
  103     | Just Neg{} <- castF f = True
  104     | Just Acc{} <- castF f = True
  105     | otherwise = False
  106 
  107 -- | Create Map String (HistoryTree (F v x)), where Key is input label and Value is FU that contain this input label
  108 containerMapCreate fs =
  109     M.unions $
  110         map
  111             ( \f ->
  112                 foldl
  113                     ( \dataMap k ->
  114                         M.insertWith (++) k [f] dataMap
  115                     )
  116                     M.empty
  117                     (S.toList $ inputs f)
  118             )
  119             fs
  120 
  121 -- | Takes container and refactor it, if it can be
  122 optimizeCluster fs = concatMap refactored fs
  123     where
  124         containerMap = containerMapCreate fs
  125 
  126         refactored f =
  127             concatMap
  128                 ( \o ->
  129                     case M.findWithDefault [] o containerMap of
  130                         [] -> []
  131                         matchedFUs -> concatMap (refactorFunction f) matchedFUs
  132                 )
  133                 (S.toList $ outputs f)
  134 
  135 refactorFunction f' f
  136     | Just (Acc lst') <- castF f'
  137     , Just (Acc lst) <- castF f
  138     , let singleOutBool = (1 ==) $ length $ outputs f'
  139           isOutInpIntersect =
  140             any
  141                 ( \case
  142                     Push _ (I v) -> elem v $ outputs f'
  143                     _ -> False
  144                 )
  145                 lst
  146           makeRefactor = singleOutBool && isOutInpIntersect
  147        in makeRefactor =
  148         let subs _ (Push Minus _) (Push Plus v) = Just $ Push Minus v
  149             subs _ (Push Minus _) (Push Minus v) = Just $ Push Plus v
  150             subs _ (Push Plus _) push@(Push _ _) = Just push
  151             subs v _ pull@(Pull _) = deleteFromPull v pull
  152             subs _ _ _ = error "Pull can not be here"
  153 
  154             refactorAcc _ _ (Pull o) = [Pull o]
  155             refactorAcc accList accNew accOld@(Push s i@(I v))
  156                 | elem v $ outputs accNew = mapMaybe (subs v accOld) accList
  157                 | s == Minus = [Push Minus i]
  158                 | s == Plus = [Push Plus i]
  159             refactorAcc _ _ (Push _ (I _)) = undefined
  160          in [packF $ Acc $ concatMap (refactorAcc lst' f') lst]
  161     | Just f1 <- fromAddSub f'
  162     , Just f2 <- fromAddSub f
  163     , (1 ==) $ length $ outputs f' = case refactorFunction f1 f2 of
  164         [fNew] -> [fNew]
  165         _ -> [f, f']
  166     | otherwise = [f, f']
  167 
  168 deleteFromPull v (Pull (O s))
  169     | S.null deleted = Nothing
  170     | otherwise = Just $ Pull $ O deleted
  171     where
  172         deleted = S.delete v s
  173 deleteFromPull _ (Push _ _) = error "delete only Pull"
  174 
  175 fromAddSub f
  176     | Just (Add in1 in2 (O out)) <- castF f =
  177         Just $
  178             acc $
  179                 [Push Plus in1, Push Plus in2] ++ [Pull $ O $ S.fromList [o] | o <- S.toList out]
  180     | Just (Sub in1 in2 (O out)) <- castF f =
  181         Just $
  182             acc $
  183                 [Push Plus in1, Push Minus in2] ++ [Pull $ O $ S.fromList [o] | o <- S.toList out]
  184     | Just (Neg in1 (O out)) <- castF f =
  185         Just $
  186             acc $
  187                 Push Minus in1 : [Pull $ O $ S.singleton o | o <- S.toList out]
  188     | Just Acc{} <- castF f = Just f
  189     | otherwise = Nothing