From f6f2469d4d73966eabbf366a981e161f350a43e7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20Conde=20G=C3=B3mez?= Date: Tue, 11 May 2021 10:24:46 +0200 Subject: [PATCH] fix: HCL "index" function now actually returns the index of the element (#11008) --- hcl2template/function/index.go | 58 +++++++++++++ hcl2template/function/index_test.go | 127 ++++++++++++++++++++++++++++ hcl2template/functions.go | 2 +- 3 files changed, 186 insertions(+), 1 deletion(-) create mode 100644 hcl2template/function/index.go create mode 100644 hcl2template/function/index_test.go diff --git a/hcl2template/function/index.go b/hcl2template/function/index.go new file mode 100644 index 000000000..0cb55a747 --- /dev/null +++ b/hcl2template/function/index.go @@ -0,0 +1,58 @@ +package function + +import ( + "errors" + + "github.com/zclconf/go-cty/cty" + "github.com/zclconf/go-cty/cty/function" + "github.com/zclconf/go-cty/cty/function/stdlib" +) + +// IndexFunc constructs a function that finds the element index for a given value in a list. +var IndexFunc = function.New(&function.Spec{ + Params: []function.Parameter{ + { + Name: "list", + Type: cty.DynamicPseudoType, + }, + { + Name: "value", + Type: cty.DynamicPseudoType, + }, + }, + Type: function.StaticReturnType(cty.Number), + Impl: func(args []cty.Value, retType cty.Type) (ret cty.Value, err error) { + if !(args[0].Type().IsListType() || args[0].Type().IsTupleType()) { + return cty.NilVal, errors.New("argument must be a list or tuple") + } + + if !args[0].IsKnown() { + return cty.UnknownVal(cty.Number), nil + } + + if args[0].LengthInt() == 0 { // Easy path + return cty.NilVal, errors.New("cannot search an empty list") + } + + for it := args[0].ElementIterator(); it.Next(); { + i, v := it.Element() + eq, err := stdlib.Equal(v, args[1]) + if err != nil { + return cty.NilVal, err + } + if !eq.IsKnown() { + return cty.UnknownVal(cty.Number), nil + } + if eq.True() { + return i, nil + } + } + return cty.NilVal, errors.New("item not found") + + }, +}) + +// Index finds the element index for a given value in a list. +func Index(list, value cty.Value) (cty.Value, error) { + return IndexFunc.Call([]cty.Value{list, value}) +} diff --git a/hcl2template/function/index_test.go b/hcl2template/function/index_test.go new file mode 100644 index 000000000..5695cde7d --- /dev/null +++ b/hcl2template/function/index_test.go @@ -0,0 +1,127 @@ +package function + +import ( + "fmt" + "testing" + + "github.com/zclconf/go-cty/cty" +) + +func TestIndex(t *testing.T) { + tests := []struct { + List cty.Value + Value cty.Value + Want cty.Value + Err bool + }{ + { + cty.ListVal([]cty.Value{ + cty.StringVal("a"), + cty.StringVal("b"), + cty.StringVal("c"), + }), + cty.StringVal("a"), + cty.NumberIntVal(0), + false, + }, + { + cty.ListVal([]cty.Value{ + cty.StringVal("a"), + cty.StringVal("b"), + cty.UnknownVal(cty.String), + }), + cty.StringVal("a"), + cty.NumberIntVal(0), + false, + }, + { + cty.ListVal([]cty.Value{ + cty.StringVal("a"), + cty.StringVal("b"), + cty.StringVal("c"), + }), + cty.StringVal("b"), + cty.NumberIntVal(1), + false, + }, + { + cty.ListVal([]cty.Value{ + cty.StringVal("a"), + cty.StringVal("b"), + cty.StringVal("c"), + }), + cty.StringVal("z"), + cty.NilVal, + true, + }, + { + cty.ListVal([]cty.Value{ + cty.StringVal("1"), + cty.StringVal("2"), + cty.StringVal("3"), + }), + cty.NumberIntVal(1), + cty.NumberIntVal(0), + true, + }, + { + cty.ListVal([]cty.Value{ + cty.NumberIntVal(1), + cty.NumberIntVal(2), + cty.NumberIntVal(3), + }), + cty.NumberIntVal(2), + cty.NumberIntVal(1), + false, + }, + { + cty.ListVal([]cty.Value{ + cty.NumberIntVal(1), + cty.NumberIntVal(2), + cty.NumberIntVal(3), + }), + cty.NumberIntVal(4), + cty.NilVal, + true, + }, + { + cty.ListVal([]cty.Value{ + cty.NumberIntVal(1), + cty.NumberIntVal(2), + cty.NumberIntVal(3), + }), + cty.StringVal("1"), + cty.NumberIntVal(0), + true, + }, + { + cty.TupleVal([]cty.Value{ + cty.NumberIntVal(1), + cty.NumberIntVal(2), + cty.NumberIntVal(3), + }), + cty.NumberIntVal(1), + cty.NumberIntVal(0), + false, + }, + } + + for _, test := range tests { + t.Run(fmt.Sprintf("index(%#v, %#v)", test.List, test.Value), func(t *testing.T) { + got, err := Index(test.List, test.Value) + + if test.Err { + if err == nil { + t.Fatal("succeeded; want error") + } + return + } else if err != nil { + t.Fatalf("unexpected error: %s", err) + } + + if !got.RawEquals(test.Want) { + t.Errorf("wrong result\ngot: %#v\nwant: %#v", got, test.Want) + } + }) + } +} diff --git a/hcl2template/functions.go b/hcl2template/functions.go index 3ac79cfec..a85f414f0 100644 --- a/hcl2template/functions.go +++ b/hcl2template/functions.go @@ -63,7 +63,7 @@ func Functions(basedir string) map[string]function.Function { "formatdate": stdlib.FormatDateFunc, "formatlist": stdlib.FormatListFunc, "indent": stdlib.IndentFunc, - "index": stdlib.IndexFunc, + "index": pkrfunction.IndexFunc, // stdlib.IndexFunc is not compatible "join": stdlib.JoinFunc, "jsondecode": stdlib.JSONDecodeFunc, "jsonencode": stdlib.JSONEncodeFunc,