never executed always true always false
    1 -- All extensions should be enabled explicitly due to doctest in this module.
    2 {-# LANGUAGE ConstraintKinds #-}
    3 {-# LANGUAGE DeriveGeneric #-}
    4 {-# LANGUAGE FlexibleContexts #-}
    5 {-# LANGUAGE FlexibleInstances #-}
    6 {-# LANGUAGE FunctionalDependencies #-}
    7 {-# LANGUAGE GADTs #-}
    8 {-# LANGUAGE ImportQualifiedPost #-}
    9 {-# LANGUAGE LambdaCase #-}
   10 {-# LANGUAGE NamedFieldPuns #-}
   11 {-# LANGUAGE TypeFamilies #-}
   12 {-# LANGUAGE UndecidableInstances #-}
   13 
   14 {- |
   15 The optimization consists of two parts:
   16 1) Replacing logic functions with lookup tables
   17 2) Searching and merging lookup tables, if possible
   18 
   19 >>> let a = constant 1 ["a"]
   20 >>> let b = constant 2 ["b"]
   21 >>> let c = constant 3 ["c"]
   22 >>> let f1 = logicAnd "a" "b" ["f1"]
   23 >>> let f2 = logicOr "f1" "c" ["f2"]
   24 >>> let loopRes = loop 1 "e" ["f2"]
   25 >>> let fs = [a, b, c, f1, f2, loopRes] :: [F String Int]
   26 >>> optimizeLogicalUnitDecision fs $ head $ optimizeLogicalUnitOptions fs
   27 [const(1) = a,const(2) = b,const(3) = c,loop(1, e) = f2,TruthTable fromList [([False,False,False],False),([False,False,True],True),([False,True,False],False),([False,True,True],True),([True,False,False],False),([True,False,True],True),([True,True,False],True),([True,True,True],True)] [a,b,c] = f2]
   28 -}
   29 module NITTA.Model.Problems.Refactor.OptimizeLogicalUnit (
   30     OptimizeLogicalUnit (..),
   31     OptimizeLogicalUnitProblem (..),
   32 )
   33 where
   34 
   35 import Control.Monad (replicateM)
   36 import Data.Foldable (foldl')
   37 import Data.List qualified as L
   38 import Data.Map qualified as M
   39 import Data.Maybe
   40 import Data.Set qualified as S
   41 import GHC.Generics
   42 import NITTA.Intermediate.Functions
   43 import NITTA.Intermediate.Types
   44 
   45 data OptimizeLogicalUnit v x = OptimizeLogicalUnit
   46     { rOld :: [F v x]
   47     , rNew :: [F v x]
   48     }
   49     deriving (Generic, Show, Eq)
   50 
   51 class OptimizeLogicalUnitProblem u v x | u -> v x where
   52     optimizeLogicalUnitOptions :: u -> [OptimizeLogicalUnit v x]
   53     optimizeLogicalUnitOptions _ = []
   54 
   55     -- | Function takes 'OptimizeLogicalUnit' and modify 'DataFlowGraph'
   56     optimizeLogicalUnitDecision :: u -> OptimizeLogicalUnit v x -> u
   57     optimizeLogicalUnitDecision _ _ = error "not implemented"
   58 
   59 instance (Var v, Val x) => OptimizeLogicalUnitProblem [F v x] v x where
   60     optimizeLogicalUnitOptions fs =
   61         let supportedFunctions = filter isSupportedByLogicalUnit fs
   62 
   63             rNew =
   64                 if not (null supportedFunctions)
   65                     && isOptimizationNeeded supportedFunctions
   66                     then optimizeCluster supportedFunctions fs
   67                     else []
   68             result =
   69                 [ OptimizeLogicalUnit{rOld = supportedFunctions, rNew}
   70                 | not (null rNew) && S.fromList supportedFunctions /= S.fromList rNew
   71                 ]
   72          in result
   73 
   74     optimizeLogicalUnitDecision fs OptimizeLogicalUnit{rOld, rNew} =
   75         deleteExtraLogicalUnits $ (fs L.\\ rOld) <> rNew
   76 
   77 deleteExtraLogicalUnits fs =
   78     L.nub
   79         [ f1
   80         | f1 <- fs
   81         , f2 <- fs
   82         , f1 /= f2
   83         , not $ S.null (variables f1 `S.intersection` variables f2)
   84         ]
   85 
   86 isOptimizationNeeded fs = countLogicalUnits fs > 1 || hasLogicFunctions fs
   87     where
   88         hasLogicFunctions fns = any isSupportedByLogicalUnit fns
   89 
   90         isLogicalUnit f = case castF f of
   91             Just (TruthTable{}) -> True
   92             _ -> False
   93 
   94         countLogicalUnits fns = length $ filter isLogicalUnit fns
   95 
   96 isSupportedByLogicalUnit f
   97     | Just LogicAnd{} <- castF f = True
   98     | Just LogicOr{} <- castF f = True
   99     | Just LogicNot{} <- castF f = True
  100     | otherwise = False
  101 
  102 optimizeCluster allFunctions _ =
  103     let clusters = findMergeClusters allFunctions
  104         mergedLogicalUnits = mapMaybe mergeCluster clusters
  105 
  106         singleFunctions = filter (\f -> isSupportedByLogicalUnit f && S.size (outputs f) /= 1) allFunctions
  107         singleLogicalUnits = mapMaybe convertToLOGICALUNIT singleFunctions
  108 
  109         remainingFunctions = allFunctions L.\\ (concat clusters ++ singleFunctions)
  110      in mergedLogicalUnits ++ singleLogicalUnits ++ remainingFunctions
  111     where
  112         mergeCluster cluster
  113             | isSingleOutputChain cluster = mergeLogicCluster M.empty cluster
  114             | otherwise = Nothing
  115 
  116         convertToLOGICALUNIT f = case castF f of
  117             Just (LogicAnd (I a) (I b) (O out)) ->
  118                 buildCombinedLOGICALUNIT
  119                     [a, b]
  120                     out
  121                     ( \case
  122                         [x, y] -> x && y
  123                         _ -> error "Unexpected pattern"
  124                     )
  125             Just (LogicOr (I a) (I b) (O out)) ->
  126                 buildCombinedLOGICALUNIT
  127                     [a, b]
  128                     out
  129                     ( \case
  130                         [x, y] -> x || y
  131                         _ -> error "Unexpected pattern"
  132                     )
  133             Just (LogicNot (I a) (O out)) ->
  134                 buildCombinedLOGICALUNIT
  135                     [a]
  136                     out
  137                     ( \case
  138                         [x] -> not x
  139                         _ -> error "Unexpected pattern"
  140                     )
  141             _ -> Nothing
  142 
  143 mergeLogicCluster _ fs =
  144     let (inputVars, finalOutput) = analyzeClusterIO fs
  145         evalFn = buildCombinedLogic fs inputVars
  146      in buildCombinedLOGICALUNIT inputVars finalOutput evalFn
  147 
  148 isSingleOutputChain fs =
  149     all (\f -> S.size (outputs f) == 1) fs
  150         && all (== 1) [S.size (outputs (fs !! i) `S.intersection` inputs (fs !! (i + 1))) | i <- [0 .. length fs - 2]]
  151 
  152 analyzeClusterIO fs =
  153     let allInputs = S.unions $ map inputs fs
  154         allOutputs = S.unions $ map outputs fs
  155         externalInputs = S.difference allInputs allOutputs
  156         finalOutput = outputs $ last fs
  157      in (S.toList externalInputs, finalOutput)
  158 
  159 buildCombinedLogic fs inputVars =
  160     let evalCombination comb =
  161             let varMap = M.fromList $ zip inputVars comb
  162                 resultMap = foldl' (\vm f -> applyLogicGate f vm) varMap fs
  163              in resultMap M.! S.elemAt 0 (outputs $ last fs)
  164      in evalCombination
  165 
  166 applyLogicGate f varMap = case castF f of
  167     Just (LogicAnd (I a) (I b) (O out)) ->
  168         case S.toList out of
  169             [outVar] -> M.insert outVar (varMap M.! a && varMap M.! b) varMap
  170             _ -> error "LogicAnd must have exactly one output: 1"
  171     Just (LogicOr (I a) (I b) (O out)) ->
  172         case S.toList out of
  173             [outVar] -> M.insert outVar (varMap M.! a || varMap M.! b) varMap
  174             _ -> error "LogicOr must have exactly one output: 2"
  175     Just (LogicNot (I a) (O out)) ->
  176         case S.toList out of
  177             [outVar] -> M.insert outVar (not $ varMap M.! a) varMap
  178             _ -> error "LogicNot must have exactly one output: 3"
  179     _ -> varMap
  180 
  181 buildCombinedLOGICALUNIT :: (Var v, Val x) => [v] -> S.Set v -> ([Bool] -> Bool) -> Maybe (F v x)
  182 buildCombinedLOGICALUNIT inputVars outputSet evalFn =
  183     let logicalunitInputs = map I inputVars
  184         logicalunitOutput = O outputSet
  185         inputCombinations = replicateM (length inputVars) [False, True]
  186         tbl = M.fromList [(comb, evalFn comb) | comb <- inputCombinations]
  187      in Just $ packF $ TruthTable tbl logicalunitInputs logicalunitOutput
  188 
  189 topSort :: Eq a => [(a, [a])] -> [a]
  190 topSort [] = []
  191 topSort g = case L.partition (null . snd) g of
  192     ([], _) -> []
  193     (ready, rest) ->
  194         map fst ready
  195             ++ topSort
  196                 [ (x, filter (`notElem` readyNodes) ys)
  197                 | (x, ys) <- rest
  198                 ]
  199         where
  200             readyNodes = map fst ready
  201 
  202 findMergeClusters :: Var v => [F v x] -> [[F v x]]
  203 findMergeClusters fs =
  204     let deps = buildDependencyGraph fs
  205         sorted = reverse $ topSort deps
  206         clusters = groupChains sorted
  207      in clusters
  208     where
  209         buildDependencyGraph fns =
  210             [ (f, [g | g <- fns, sharesDependency f g])
  211             | f <- fns
  212             ]
  213 
  214         sharesDependency f g =
  215             not $ S.null (outputs f `S.intersection` inputs g)
  216 
  217         groupChains [] = []
  218         groupChains (x : xs) =
  219             let (chain, rest) = collectChain [x] xs
  220              in chain : groupChains rest
  221             where
  222                 collectChain acc' [] = (acc', [])
  223                 collectChain acc' (y : ys)
  224                     | sharesDependency (last acc') y
  225                         && isSingleOutputChain (acc' ++ [y]) =
  226                         collectChain (acc' ++ [y]) ys
  227                     | otherwise = (acc', y : ys)