Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add tuple return values for par exprs in codegen #157

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions regression-tests/tests/foobar1.ssl
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
type Pair2 a b
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why was this not there before? And is it being used?

Pair2 a b

putip_ putc x =
if x < 10
putc (x + 48)
Expand Down
1 change: 1 addition & 0 deletions regression-tests/tests/par-ret-vals.out
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
C D
21 changes: 21 additions & 0 deletions regression-tests/tests/par-ret-vals.ssl
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
// check that par return values work
// only currently support 2-tuple return values
add a b = a + b

type Pair2 a b
Pair2 a b

printCharTuple2 putc1 p1 =
match p1
Pair2 x1 y1 = putc1 x1
putc1 32
putc1 y1

main cin cout =
let putc c = after 1, cout <- c
wait cout
let putnl _ = putc 10
let z = par add 60 7 // intialize a par expression
add 60 8
printCharTuple2 putc z
putnl 4
137 changes: 79 additions & 58 deletions src/Codegen/Codegen.hs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,9 @@ import qualified Language.C.Syntax as C
import qualified Common.Compiler as Compiler
import Common.Identifiers (fromId, fromString)

import Control.Monad (foldM, unless)
import Control.Monad (foldM, unless, when)

-- import Control.Monad (foldM, unless)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove the commented-out import

import Control.Monad.Except (MonadError (..))
import Control.Monad.State.Lazy (
MonadState,
Expand Down Expand Up @@ -80,26 +82,26 @@ should be computed first, before this information is used to generate the act
struct and enter definitions.
-}
data GenFnState = GenFnState
{ fnName :: I.VarId
-- ^ Function name
, fnParams :: [I.Binder I.Type]
-- ^ Function parameters
, fnRetTy :: I.Type
-- ^ Function return type
, fnBody :: I.Expr I.Type
-- ^ Function body
, fnLocals :: M.Map I.VarId I.Type
-- ^ Function local variables
, fnVars :: M.Map I.VarId C.Exp
-- ^ How to resolve variables
, fnMaxWaits :: Int
-- ^ Number of triggers needed
, fnCases :: Int
-- ^ Yield point counter
, fnFresh :: Int
-- ^ Temporary variable name counter
, fnTypeInfo :: TypegenInfo
-- ^ (User-defined) type information
{ -- | Function name
fnName :: I.VarId
, -- | Function parameters
fnParams :: [I.Binder I.Type]
, -- | Function return type
fnRetTy :: I.Type
, -- | Function body
fnBody :: I.Expr I.Type
, -- | Function local variables
fnLocals :: M.Map I.VarId I.Type
, -- | How to resolve variables
fnVars :: M.Map I.VarId C.Exp
, -- | Number of triggers needed
fnMaxWaits :: Int
, -- | Yield point counter
fnCases :: Int
, -- | Temporary variable name counter
fnFresh :: Int
, -- | (User-defined) type information
fnTypeInfo :: TypegenInfo
}


Expand All @@ -119,21 +121,21 @@ newtype GenFn a = GenFn (StateT GenFnState Compiler.Pass a)


-- | Run a 'GenFn' computation on a procedure.
runGenFn
:: I.VarId
-- ^ Name of procedure
-> [I.Binder I.Type]
-- ^ Names and types of parameters to procedure
-> I.Expr I.Type
-- ^ Body of procedure
-> TypegenInfo
-- ^ Type information
-> [(I.VarId, I.Type)]
-- ^ Other global identifiers
-> GenFn a
-- ^ Translation monad to run
-> Compiler.Pass a
-- ^ Pass on errors to caller
runGenFn ::
-- | Name of procedure
I.VarId ->
-- | Names and types of parameters to procedure
[I.Binder I.Type] ->
-- | Body of procedure
I.Expr I.Type ->
-- | Type information
TypegenInfo ->
-- | Other global identifiers
[(I.VarId, I.Type)] ->
-- | Translation monad to run
GenFn a ->
-- | Pass on errors to caller
Compiler.Pass a
runGenFn name params body typeInfo globals (GenFn tra) =
evalStateT tra $
GenFnState
Expand Down Expand Up @@ -299,10 +301,10 @@ genProgram p = do
++ cdefns
++ genInitProgram (I.programEntry p)
where
genTop
:: TypegenInfo
-> (I.Binder I.Type, I.Expr I.Type)
-> Compiler.Pass ([C.Definition], [C.Definition])
genTop ::
TypegenInfo ->
(I.Binder I.Type, I.Expr I.Type) ->
Compiler.Pass ([C.Definition], [C.Definition])
genTop tinfo (I.BindVar name _, [email protected]{}) =
runGenFn (fromId name) argIds body tinfo tops $ do
(stepDecl, stepDefn) <- genStep
Expand Down Expand Up @@ -696,11 +698,11 @@ genExpr (I.Match s as t) = do
mkBlk :: CIdent -> [C.BlockItem] -> [C.BlockItem]
mkBlk label blk =
[citems|$id:label:;|] ++ blk ++ [citems|goto $id:joinLabel;|]
withAltScope
:: CIdent
-> I.Alt I.Type
-> GenFn [C.BlockItem]
-> GenFn (C.BlockItem, [C.BlockItem])
withAltScope ::
CIdent ->
I.Alt I.Type ->
GenFn [C.BlockItem] ->
GenFn (C.BlockItem, [C.BlockItem])
withAltScope label a@(I.AltData dcon _ _) m = do
destruct <- getsDCon dconDestruct dcon
cas <- getsDCon dconCase dcon
Expand Down Expand Up @@ -728,8 +730,8 @@ genExpr (I.Exception _ t) = do


-- | Generate code for SSM primitive; see 'genExpr' for extended discussion.
genPrim
:: I.Primitive -> [I.Expr I.Type] -> I.Type -> GenFn (C.Exp, [C.BlockItem])
genPrim ::
I.Primitive -> [I.Expr I.Type] -> I.Type -> GenFn (C.Exp, [C.BlockItem])
genPrim I.New [e] refType = do
(val, stms) <- genExpr e
tmp <- genTmp refType
Expand Down Expand Up @@ -762,13 +764,13 @@ genPrim I.After [time, lhs, rhs] _ = do
(timeVal, timeStms) <- genExpr time
(lhsVal, lhsStms) <- genExpr lhs
(rhsVal, rhsStms) <- genExpr rhs
let when = [cexp|$exp:now() + $exp:(unmarshal timeVal)|]
let when' = [cexp|$exp:now() + $exp:(unmarshal timeVal)|]
laterBlock =
[citems|
$items:timeStms
$items:lhsStms
$items:rhsStms
$exp:(later lhsVal when rhsVal);
$exp:(later lhsVal when' rhsVal);
$exp:(drop timeVal);
$exp:(drop rhsVal);
$exp:(drop lhsVal);
Expand All @@ -792,9 +794,9 @@ genPrim I.Par procs _ = do
-- implemented just yet.
-- So, this is currently broken in that side effects inside the arguments
-- of function calls will be evaluated sequentially, which is wrong.
apply
:: (I.Expr I.Type, (C.Exp, C.Exp))
-> GenFn (C.Exp, [C.BlockItem], [C.BlockItem])
apply ::
(I.Expr I.Type, (C.Exp, C.Exp)) ->
GenFn (C.Exp, [C.BlockItem], [C.BlockItem])
apply (I.App fn arg ty, (prio, depth)) = do
(fnExp, fnStms) <- genExpr fn
(argExp, argStms) <- genExpr arg
Expand All @@ -809,12 +811,31 @@ genPrim I.Par procs _ = do
return (ret, fnStms ++ argStms, appStms)
apply (e, _) = do
fail $ "Cannot compile par with non-application expression: " ++ show e
-- given a list of par return vals and their types, wrap the return vals in a tuple
genParRetVal :: [C.Exp] -> [I.Type] -> GenFn (C.Exp, [C.BlockItem])
genParRetVal [] _ = fail "par should have 2 or more return values"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: You can just put this as a _ wildcard case after the rets@(ret0:ret1:_) case, to consolidate them

genParRetVal [_] _ = fail "par should have 2 or more return values"
genParRetVal rets@(ret0 : ret1 : _) retTys = do
-- TODO: given n ret vals, return an n-ary tuple
when (length rets /= length retTys) $ do fail "lists of return vals and types must be same length"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Put the fail statement on the line after the do, i.e.,

when (length rets /= length retTys) $ do
  fail "lists of return vals and types must be same length"

let dcon = I.tempTupleId (length rets)
let dty = I.tempTuple retTys
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove the redundant let here, should just be:

let dcon = ...
    dty = ...

onHeap <- getsDCon dconOnHeap dcon
unless onHeap $ do
fail $ "Cannot handle packed fields yet, for: " ++ show dcon
construct <- getsDCon dconConstruct dcon
destruct <- getsDCon dconDestruct dcon
tmp <- genTmp dty
let alloc = [[citem|$exp:tmp = $exp:construct;|]]
initField y i = [citem|$exp:(destruct i tmp) = $exp:y;|]
initFields = zipWith initField [ret0, ret1] [0 ..] -- puts first two return vals in a 2-tuple for now
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why only the first two values? Why not zipWith initField rets [0..]?

return (tmp, alloc ++ initFields)
Comment on lines +814 to +832
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

creates a 2-tuple with the first two ret vals of par as its arguments


(_rets, befores, activates) <- unzip3 <$> mapM apply (zip procs parArgs)
yield <- genYield
let parRetVal = unit -- TODO: return tuple of values
(parRetVal, tupleStms) <- genParRetVal _rets (I.extract <$> procs)
return
(parRetVal, checkNewDepth ++ concat befores ++ concat activates ++ yield)
(parRetVal, checkNewDepth ++ concat befores ++ concat activates ++ yield ++ tupleStms)
genPrim I.Wait vars _ = do
(varVals, varStms) <- unzip <$> mapM genExpr vars
maxWait $ length varVals
Expand Down Expand Up @@ -866,8 +887,8 @@ genLiteralRaw I.LitEvent = [cexp|1|]


-- | Generate C expression for SSM primitive operation.
genPrimOp
:: I.PrimOp -> [I.Expr I.Type] -> I.Type -> GenFn (C.Exp, [C.BlockItem])
genPrimOp ::
I.PrimOp -> [I.Expr I.Type] -> I.Type -> GenFn (C.Exp, [C.BlockItem])
genPrimOp I.PrimAdd [lhs, rhs] _ = do
((lhsVal, rhsVal), stms) <-
first (bimap unmarshal unmarshal) <$> genBinop lhs rhs
Expand Down Expand Up @@ -946,8 +967,8 @@ genPrimOp _ _ _ = fail "Unsupported PrimOp or wrong number of arguments"


-- | Helper for sequencing across binary operations.
genBinop
:: I.Expr I.Type -> I.Expr I.Type -> GenFn ((C.Exp, C.Exp), [C.BlockItem])
genBinop ::
I.Expr I.Type -> I.Expr I.Type -> GenFn ((C.Exp, C.Exp), [C.BlockItem])
genBinop lhs rhs = do
(lhsVal, lhsStms) <- genExpr lhs
(rhsVal, rhsStms) <- genExpr rhs
Expand Down
4 changes: 2 additions & 2 deletions src/IR/Constraint/Canonical.hs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ import IR.Types.Type (
Type (..),
builtinKinds,
foldArrow,
tupleId,
tempTupleId,
unAnnotations,
unfoldArrow,
)
Expand Down Expand Up @@ -141,7 +141,7 @@ pattern U8 = TCon "UInt8" []

-- | Construct a builtin tuple type out of a list of at least 2 types.
tuple :: [Type] -> Type
tuple ts = TCon (tupleId $ length ts) ts
tuple ts = TCon (tempTupleId $ length ts) ts


-- | ANNOTATION
Expand Down
12 changes: 12 additions & 0 deletions src/IR/Types/Type.hs
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,14 @@ isNum :: Type -> Bool
isNum t = isInt t || isUInt t


{- | Construct a builtin tuple type out of a list of at least 2 types

Uses tempTupleId instead of tupleId
-}
tempTuple :: [Type] -> Type
tempTuple ts = TCon (tempTupleId $ length ts) ts


-- | Construct a builtin tuple type out of a list of at least 2 types.
tuple :: [Type] -> Type
tuple ts
Expand All @@ -276,6 +284,10 @@ tupleId i
| otherwise = error $ "Cannot create tuple of arity: " ++ show (toInteger i)


{- | Construct the type constructor of a builtin tuple of given arity (>= 2).

Instead of using parentheses, use temporary data constructor name
-}
tempTupleId :: (Integral i, Identifiable v) => i -> v
tempTupleId i
| i >= 2 = fromString $ "Pair" ++ show (toInteger i)
Expand Down
6 changes: 5 additions & 1 deletion test/semant/Tests/TypeInferSpec.hs
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,8 @@ spec = do
describe "larger, full-program test cases" $ do
it "typechecks an unannotated, monomorphic program" $ do
[here|
type Pair2 a b
Pair2 a b
toggle(led : &Int) -> () =
(led: &Int) <- (1 - deref (led: &Int): Int)
slow(led : &Int) -> () =
Expand All @@ -357,10 +359,12 @@ spec = do
((toggle: &Int -> ()) (led: &Int): ())
after 20 , (e2: &()) <- ()
wait (e2: &())
main(led : &Int) -> ((), ()) =
main(led : &Int) -> (Pair2 () ()) =
par ((slow: &Int -> ()) (led: &Int): ())
((fast: &Int -> ()) (led: &Int): ())
|] `typeChecksAs` [here|
type Pair2 a b
Pair2 a b
toggle(led) =
led <- 1 - deref led
slow(led) =
Expand Down