diff --git a/interpreter/expression.go b/interpreter/expression.go index 4ab9cb74..199d92c0 100644 --- a/interpreter/expression.go +++ b/interpreter/expression.go @@ -36,7 +36,7 @@ func (i *Interpreter) IdentValue(val string, withCondition bool) (value.Value, e } else if _, ok := i.ctx.Ratecounters[val]; ok { return &value.Ident{Value: val, Literal: true}, nil } else if strings.HasPrefix(val, "var.") { - if v, err := i.localVars.Get(val); err != nil { + if v, err := i.StackPointer.Locals.Get(val); err != nil { return value.Null, errors.WithStack(err) } else { return v, nil diff --git a/interpreter/interpreter.go b/interpreter/interpreter.go index 2170582b..0364b9c5 100644 --- a/interpreter/interpreter.go +++ b/interpreter/interpreter.go @@ -22,10 +22,16 @@ import ( "github.com/ysugimoto/falco/parser" ) +const MaxStackDepth = 100 + +type StackFrame struct { + Locals variable.LocalVariables + Subroutine *ast.SubroutineDeclaration +} + type Interpreter struct { - vars variable.Variable - localVars variable.LocalVariables - lock sync.Mutex + vars variable.Variable + lock sync.Mutex options []context.Option @@ -36,15 +42,20 @@ type Interpreter struct { IdentResolver func(v string) value.Value TestingState State + + Stack []*StackFrame + StackPointer *StackFrame } func New(options ...context.Option) *Interpreter { + stack := []*StackFrame{{Locals: variable.LocalVariables{}}} return &Interpreter{ options: options, cache: cache.New(), - localVars: variable.LocalVariables{}, Debugger: DefaultDebugger{}, TestingState: NONE, + Stack: stack, + StackPointer: stack[0], } } diff --git a/interpreter/statement.go b/interpreter/statement.go index d8c473c0..40d51df4 100644 --- a/interpreter/statement.go +++ b/interpreter/statement.go @@ -208,7 +208,7 @@ func (i *Interpreter) ProcessBlockStatement( } func (i *Interpreter) ProcessDeclareStatement(stmt *ast.DeclareStatement) error { - return i.localVars.Declare(stmt.Name.Value, stmt.ValueType.Value) + return i.StackPointer.Locals.Declare(stmt.Name.Value, stmt.ValueType.Value) } func (i *Interpreter) ProcessReturnStatement(stmt *ast.ReturnStatement) State { @@ -225,7 +225,7 @@ func (i *Interpreter) ProcessSetStatement(stmt *ast.SetStatement) error { } if strings.HasPrefix(stmt.Ident.Value, "var.") { - err = i.localVars.Set(stmt.Ident.Value, stmt.Operator.Operator, right) + err = i.StackPointer.Locals.Set(stmt.Ident.Value, stmt.Operator.Operator, right) } else { err = i.vars.Set(i.ctx.Scope, stmt.Ident.Value, stmt.Operator.Operator, right) } @@ -264,7 +264,7 @@ func (i *Interpreter) ProcessAddStatement(stmt *ast.AddStatement) error { func (i *Interpreter) ProcessUnsetStatement(stmt *ast.UnsetStatement) error { var err error if strings.HasPrefix(stmt.Ident.Value, "var.") { - err = i.localVars.Unset(stmt.Ident.Value) + err = i.StackPointer.Locals.Unset(stmt.Ident.Value) } else { err = i.vars.Unset(i.ctx.Scope, stmt.Ident.Value) } @@ -279,7 +279,7 @@ func (i *Interpreter) ProcessRemoveStatement(stmt *ast.RemoveStatement) error { // Alias of unset var err error if strings.HasPrefix(stmt.Ident.Value, "var.") { - err = i.localVars.Unset(stmt.Ident.Value) + err = i.StackPointer.Locals.Unset(stmt.Ident.Value) } else { err = i.vars.Unset(i.ctx.Scope, stmt.Ident.Value) } diff --git a/interpreter/statement_test.go b/interpreter/statement_test.go index 0475d07a..76b7c135 100644 --- a/interpreter/statement_test.go +++ b/interpreter/statement_test.go @@ -95,7 +95,7 @@ func TestDeclareStatement(t *testing.T) { for _, tt := range tests { ip := New(nil) - err := ip.localVars.Declare(tt.decl.Name.Value, tt.decl.ValueType.Value) + err := ip.StackPointer.Locals.Declare(tt.decl.Name.Value, tt.decl.ValueType.Value) if err != nil { if !tt.isError { t.Errorf("%s: unexpected error returned: %s", tt.name, err) @@ -103,7 +103,7 @@ func TestDeclareStatement(t *testing.T) { continue } - v, err := ip.localVars.Get(tt.decl.Name.Value) + v, err := ip.StackPointer.Locals.Get(tt.decl.Name.Value) if err != nil { t.Errorf("%s: %s varible must be declared: %s", tt.name, tt.decl.Name.Value, err) continue @@ -194,7 +194,7 @@ func TestSetStatement(t *testing.T) { for _, tt := range tests { ip := New(nil) - if err := ip.localVars.Declare("var.foo", "INTEGER"); err != nil { + if err := ip.StackPointer.Locals.Declare("var.foo", "INTEGER"); err != nil { t.Errorf("%s: unexpected error returned: %s", tt.name, err) } diff --git a/interpreter/subroutine.go b/interpreter/subroutine.go index 18b96c24..48e3e800 100644 --- a/interpreter/subroutine.go +++ b/interpreter/subroutine.go @@ -12,20 +12,50 @@ import ( "github.com/ysugimoto/falco/interpreter/variable" ) -func (i *Interpreter) ProcessSubroutine(sub *ast.SubroutineDeclaration, ds DebugState) (State, error) { - i.process.Flows = append(i.process.Flows, process.NewFlow(i.ctx, process.WithSubroutine(sub))) +func (i *Interpreter) subroutineInStack(sub *ast.SubroutineDeclaration) bool { + for _, s := range i.Stack { + if s.Subroutine == sub { + return true + } + } + return false +} - // Store the current values and restore after subroutine has ended - regex := i.ctx.RegexMatchedValues - local := i.localVars - i.ctx.RegexMatchedValues = make(map[string]*value.String) - i.localVars = variable.LocalVariables{} +func (i *Interpreter) pushStackFrame(sub *ast.SubroutineDeclaration) error { + sf := &StackFrame{ + Locals: variable.LocalVariables{}, + Subroutine: sub, + } + i.Stack = append(i.Stack, sf) + if len(i.Stack) > MaxStackDepth { + return errors.WithStack(exception.Runtime(&sub.Token, "max stack depth exceeded")) + } + i.StackPointer = sf + return nil +} + +func (i *Interpreter) popStackFrame() { + var sf *StackFrame + sf, i.Stack = i.Stack[len(i.Stack)-1], i.Stack[:len(i.Stack)-1] + if len(i.Stack) > 0 { + i.StackPointer = i.Stack[len(i.Stack)-1] + } else { + i.StackPointer = nil + } + i.ctx.SubroutineCalls[sf.Subroutine.Name.Value]++ +} - defer func() { - i.ctx.RegexMatchedValues = regex - i.localVars = local - i.ctx.SubroutineCalls[sub.Name.Value]++ - }() +func (i *Interpreter) ProcessSubroutine(sub *ast.SubroutineDeclaration, ds DebugState) (State, error) { + if i.subroutineInStack(sub) { + return NONE, errors.WithStack( + errors.Errorf("Recursion detected, subroutine %s already in stack", sub.Name.Value), + ) + } + i.process.Flows = append(i.process.Flows, process.NewFlow(i.ctx, sub)) + if err := i.pushStackFrame(sub); err != nil { + return NONE, errors.WithStack(err) + } + defer i.popStackFrame() // Try to extract fastly reserved subroutine macro if err := i.extractBoilerplateMacro(sub); err != nil { @@ -44,19 +74,16 @@ func (i *Interpreter) ProcessSubroutine(sub *ast.SubroutineDeclaration, ds Debug // nolint: gocognit func (i *Interpreter) ProcessFunctionSubroutine(sub *ast.SubroutineDeclaration, ds DebugState) (value.Value, State, error) { - i.process.Flows = append(i.process.Flows, process.NewFlow(i.ctx, process.WithSubroutine(sub))) - - // Store the current values and restore after subroutine has ended - regex := i.ctx.RegexMatchedValues - local := i.localVars - i.ctx.RegexMatchedValues = make(map[string]*value.String) - i.localVars = variable.LocalVariables{} - - defer func() { - i.ctx.RegexMatchedValues = regex - i.localVars = local - i.ctx.SubroutineCalls[sub.Name.Value]++ - }() + if i.subroutineInStack(sub) { + return value.Null, NONE, errors.WithStack( + errors.Errorf("Recursion detected, subroutine %s already in stack", sub.Name.Value), + ) + } + i.process.Flows = append(i.process.Flows, process.NewFlow(i.ctx, sub)) + if err := i.pushStackFrame(sub); err != nil { + return value.Null, NONE, errors.WithStack(err) + } + defer i.popStackFrame() var err error var debugState DebugState = ds diff --git a/interpreter/subroutine_test.go b/interpreter/subroutine_test.go index b1350bf9..867bad0c 100644 --- a/interpreter/subroutine_test.go +++ b/interpreter/subroutine_test.go @@ -3,8 +3,12 @@ package interpreter import ( "testing" + "github.com/ysugimoto/falco/ast" "github.com/ysugimoto/falco/interpreter/context" + "github.com/ysugimoto/falco/interpreter/process" "github.com/ysugimoto/falco/interpreter/value" + "github.com/ysugimoto/falco/lexer" + "github.com/ysugimoto/falco/parser" ) func TestSubroutine(t *testing.T) { @@ -44,6 +48,17 @@ func TestSubroutine(t *testing.T) { }`, isError: true, }, + { + name: "Recursion produces an error not a panic", + vcl: `sub func { + call func(); + } + + sub vcl_recv { + call func(); + }`, + isError: true, + }, } for _, tt := range tests { @@ -58,6 +73,7 @@ func TestFunctionSubroutine(t *testing.T) { name string vcl string assertions map[string]value.Value + isError bool }{ { name: "Functional subroutine returns a value", @@ -134,11 +150,47 @@ func TestFunctionSubroutine(t *testing.T) { "req.http.X-Int-Value": &value.String{Value: "2"}, }, }, + { + name: "Recursion produces an error not a panic", + vcl: `sub func STRING { + return func(); + } + + sub vcl_recv { + set req.http.foo = func(); + }`, + isError: true, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - assertInterpreter(t, tt.vcl, context.RecvScope, tt.assertions, false) + assertInterpreter(t, tt.vcl, context.RecvScope, tt.assertions, tt.isError) }) } } + +func TestMaxStackDepth(t *testing.T) { + vcl, err := parser.New(lexer.NewFromString(`sub test STRING { return ""; }`)).ParseVCL() + if err != nil { + t.Errorf("VCL parser error: %s", err) + return + } + ip := New() + ip.ctx = context.New() + ip.process = process.New() + if err = ip.ProcessDeclarations(vcl.Statements); err != nil { + t.Errorf("Failed to process statement: %s", err) + return + } + for i := 0; i < MaxStackDepth; i++ { + ip.Stack = append(ip.Stack, &StackFrame{Subroutine: &ast.SubroutineDeclaration{}}) + } + n := ast.FunctionCallExpression{ + Function: &ast.Ident{Value: "test"}, + } + _, err = ip.ProcessFunctionCallExpression(&n, false) + if err == nil { + t.Error("Expected error got nil") + } +} diff --git a/interpreter/variable/all.go b/interpreter/variable/all.go index 2961e172..f1964ff4 100644 --- a/interpreter/variable/all.go +++ b/interpreter/variable/all.go @@ -636,6 +636,7 @@ func (v *AllScopeVariables) getFromRegex(name string) value.Value { if val, ok := v.ctx.RegexMatchedValues[match[1]]; ok { return val } + return &value.String{IsNotSet: true} } // HTTP request header matching