From 660e806aee90e3691828e6ce4661849192e5603c Mon Sep 17 00:00:00 2001 From: Jeremy Quirke Date: Sun, 5 Mar 2023 22:09:07 -0800 Subject: [PATCH] Support simultanous name and group tags As per Dig issue: https://github.com/uber-go/dig/issues/380 In order to support Fx feature requests https://github.com/uber-go/fx/issues/998 https://github.com/uber-go/fx/issues/1036 We need to be able to drop the restriction, both in terms of options dig.Name and dig.Group and dig.Out struct annotations on `name` and `group` being mutually exclusive. In a future PR, this can then be exploited to populate value group maps where the 'name' tag becomes the key of a map[string][T] --- decorate.go | 2 +- dig_test.go | 134 +++++++++++++++++++++++++++++++++++++++++++------ provide.go | 6 --- result.go | 95 ++++++++++++++++++++++++----------- result_test.go | 54 +++++++++++--------- 5 files changed, 217 insertions(+), 74 deletions(-) diff --git a/decorate.go b/decorate.go index 8c4105e9..3a7114c5 100644 --- a/decorate.go +++ b/decorate.go @@ -288,7 +288,7 @@ func findResultKeys(r resultList) ([]key, error) { keys = append(keys, key{t: innerResult.Type.Elem(), group: innerResult.Group}) case resultObject: for _, f := range innerResult.Fields { - q = append(q, f.Result) + q = append(q, f.Results...) } case resultList: q = append(q, innerResult.Results...) diff --git a/dig_test.go b/dig_test.go index 69b10f9d..ce2829e4 100644 --- a/dig_test.go +++ b/dig_test.go @@ -749,6 +749,53 @@ func TestEndToEndSuccess(t *testing.T) { assert.ElementsMatch(t, actualStrs, expectedStrs, "list of strings provided must match") }) + t.Run("multiple As with Group and Name", func(t *testing.T) { + c := digtest.New(t) + expectedNames := []string{"inst1", "inst2"} + expectedStrs := []string{"foo", "bar"} + for i, s := range expectedStrs { + s := s + c.RequireProvide(func() *bytes.Buffer { + return bytes.NewBufferString(s) + }, dig.Group("buffs"), dig.Name(expectedNames[i]), + dig.As(new(io.Reader), new(io.Writer))) + } + + type in struct { + dig.In + + Reader1 io.Reader `name:"inst1"` + Reader2 io.Reader `name:"inst2"` + Readers []io.Reader `group:"buffs"` + Writers []io.Writer `group:"buffs"` + } + + var actualStrs []string + var actualStrsName []string + + c.RequireInvoke(func(got in) { + require.Len(t, got.Readers, 2) + buf := make([]byte, 3) + for i, r := range got.Readers { + _, err := r.Read(buf) + require.NoError(t, err) + actualStrs = append(actualStrs, string(buf)) + // put the text back + got.Writers[i].Write(buf) + } + _, err := got.Reader1.Read(buf) + require.NoError(t, err) + actualStrsName = append(actualStrsName, string(buf)) + _, err = got.Reader2.Read(buf) + require.NoError(t, err) + actualStrsName = append(actualStrsName, string(buf)) + require.Len(t, got.Writers, 2) + }) + + assert.ElementsMatch(t, actualStrs, expectedStrs, "list of strings provided must match") + assert.ElementsMatch(t, actualStrsName, expectedStrs, "names: list of strings provided must match") + }) + t.Run("As same interface", func(t *testing.T) { c := digtest.New(t) c.RequireProvide(func() io.Reader { @@ -1098,6 +1145,48 @@ func TestGroups(t *testing.T) { }) }) + t.Run("values are provided; coexist with name", func(t *testing.T) { + c := digtest.New(t, dig.SetRand(rand.New(rand.NewSource(0)))) + + type out struct { + dig.Out + + Value int `group:"val"` + } + + type out2 struct { + dig.Out + + Value int `name:"inst1" group:"val"` + } + + provide := func(i int) { + c.RequireProvide(func() out { + return out{Value: i} + }) + } + + provide(1) + provide(2) + provide(3) + + c.RequireProvide(func() out2 { + return out2{Value: 4} + }) + + type in struct { + dig.In + + SingleValue int `name:"inst1"` + Values []int `group:"val"` + } + + c.RequireInvoke(func(i in) { + assert.Equal(t, []int{1, 2, 3, 4}, i.Values) + assert.Equal(t, 4, i.SingleValue) + }) + }) + t.Run("groups are provided via option", func(t *testing.T) { c := digtest.New(t, dig.SetRand(rand.New(rand.NewSource(0)))) @@ -1122,6 +1211,36 @@ func TestGroups(t *testing.T) { }) }) + t.Run("groups are provided via option; coexist with name", func(t *testing.T) { + c := digtest.New(t, dig.SetRand(rand.New(rand.NewSource(0)))) + + provide := func(i int) { + c.RequireProvide(func() int { + return i + }, dig.Group("val")) + } + + provide(1) + provide(2) + provide(3) + + c.RequireProvide(func() int { + return 4 + }, dig.Group("val"), dig.Name("inst1")) + + type in struct { + dig.In + + SingleValue int `name:"inst1"` + Values []int `group:"val"` + } + + c.RequireInvoke(func(i in) { + assert.Equal(t, []int{1, 2, 3, 4}, i.Values) + assert.Equal(t, 4, i.SingleValue) + }) + }) + t.Run("different types may be grouped", func(t *testing.T) { c := digtest.New(t, dig.SetRand(rand.New(rand.NewSource(0)))) @@ -1998,21 +2117,6 @@ func TestAsExpectingOriginalType(t *testing.T) { }) } -func TestProvideIncompatibleOptions(t *testing.T) { - t.Parallel() - - t.Run("group and name", func(t *testing.T) { - c := digtest.New(t) - err := c.Provide(func() io.Reader { - t.Fatal("this function must not be called") - return nil - }, dig.Group("foo"), dig.Name("bar")) - require.Error(t, err) - assert.Contains(t, err.Error(), "cannot use named values with value groups: "+ - `name:"bar" provided with group:"foo"`) - }) -} - type testStruct struct{} func (testStruct) TestMethod(x int) float64 { return float64(x) } diff --git a/provide.go b/provide.go index 91a4a920..277c1b3d 100644 --- a/provide.go +++ b/provide.go @@ -46,12 +46,6 @@ type provideOptions struct { } func (o *provideOptions) Validate() error { - if len(o.Group) > 0 { - if len(o.Name) > 0 { - return newErrInvalidInput( - fmt.Sprintf("cannot use named values with value groups: name:%q provided with group:%q", o.Name, o.Group), nil) - } - } // Names must be representable inside a backquoted string. The only // limitation for raw string literals as per diff --git a/result.go b/result.go index 369cd218..936e0dbe 100644 --- a/result.go +++ b/result.go @@ -66,7 +66,7 @@ type resultOptions struct { } // newResult builds a result from the given type. -func newResult(t reflect.Type, opts resultOptions) (result, error) { +func newResult(t reflect.Type, opts resultOptions, noGroup bool) (result, error) { switch { case IsIn(t) || (t.Kind() == reflect.Ptr && IsIn(t.Elem())) || embedsType(t, _inPtrType): return nil, newErrInvalidInput(fmt.Sprintf( @@ -81,7 +81,7 @@ func newResult(t reflect.Type, opts resultOptions) (result, error) { case t.Kind() == reflect.Ptr && IsOut(t.Elem()): return nil, newErrInvalidInput(fmt.Sprintf( "cannot return a pointer to a result object, use a value instead: %v is a pointer to a struct that embeds dig.Out", t), nil) - case len(opts.Group) > 0: + case len(opts.Group) > 0 && !noGroup: g, err := parseGroupString(opts.Group) if err != nil { return nil, newErrInvalidInput( @@ -176,7 +176,9 @@ func walkResult(r result, v resultVisitor) { w := v for _, f := range res.Fields { if v := w.AnnotateWithField(f); v != nil { - walkResult(f.Result, v) + for _, r := range f.Results { + walkResult(r, v) + } } } case resultList: @@ -200,7 +202,7 @@ type resultList struct { // For each item at index i returned by the constructor, resultIndexes[i] // is the index in .Results for the corresponding result object. // resultIndexes[i] is -1 for errors returned by constructors. - resultIndexes []int + resultIndexes [][]int } func (rl resultList) DotResult() []*dot.Result { @@ -216,25 +218,47 @@ func newResultList(ctype reflect.Type, opts resultOptions) (resultList, error) { rl := resultList{ ctype: ctype, Results: make([]result, 0, numOut), - resultIndexes: make([]int, numOut), + resultIndexes: make([][]int, numOut), } resultIdx := 0 for i := 0; i < numOut; i++ { t := ctype.Out(i) if isError(t) { - rl.resultIndexes[i] = -1 + rl.resultIndexes[i] = append(rl.resultIndexes[i], -1) continue } - r, err := newResult(t, opts) - if err != nil { - return rl, newErrInvalidInput(fmt.Sprintf("bad result %d", i+1), err) + addResult := func(nogroup bool) error { + r, err := newResult(t, opts, nogroup) + if err != nil { + return newErrInvalidInput(fmt.Sprintf("bad result %d", i+1), err) + } + + rl.Results = append(rl.Results, r) + rl.resultIndexes[i] = append(rl.resultIndexes[i], resultIdx) + resultIdx++ + return nil + } + + // special case, its added as a group and a name using options alone + if len(opts.Name) > 0 && len(opts.Group) > 0 && !IsOut(t) { + // add as a group + if err := addResult(false); err != nil { + return rl, err + } + // add as single + if err := addResult(true); err != nil { + return rl, err + } + return rl, nil + } + + // add as normal + if err := addResult(false); err != nil { + return rl, err } - rl.Results = append(rl.Results, r) - rl.resultIndexes[i] = resultIdx - resultIdx++ } return rl, nil @@ -246,8 +270,14 @@ func (resultList) Extract(containerWriter, bool, reflect.Value) { func (rl resultList) ExtractList(cw containerWriter, decorated bool, values []reflect.Value) error { for i, v := range values { - if resultIdx := rl.resultIndexes[i]; resultIdx >= 0 { - rl.Results[resultIdx].Extract(cw, decorated, v) + isNonErrorResult := false + for _, resultIdx := range rl.resultIndexes[i] { + if resultIdx >= 0 { + rl.Results[resultIdx].Extract(cw, decorated, v) + isNonErrorResult = true + } + } + if isNonErrorResult { continue } @@ -384,7 +414,9 @@ func newResultObject(t reflect.Type, opts resultOptions) (resultObject, error) { func (ro resultObject) Extract(cw containerWriter, decorated bool, v reflect.Value) { for _, f := range ro.Fields { - f.Result.Extract(cw, decorated, v.Field(f.FieldIndex)) + for _, r := range f.Results { + r.Extract(cw, decorated, v.Field(f.FieldIndex)) + } } } @@ -399,12 +431,16 @@ type resultObjectField struct { // map to results. FieldIndex int - // Result produced by this field. - Result result + // Results produced by this field. + Results []result } func (rof resultObjectField) DotResult() []*dot.Result { - return rof.Result.DotResult() + results := make([]*dot.Result, 0, len(rof.Results)) + for _, r := range rof.Results { + results = append(results, r.DotResult()...) + } + return results } // newResultObjectField(i, f, opts) builds a resultObjectField from the field @@ -414,7 +450,11 @@ func newResultObjectField(idx int, f reflect.StructField, opts resultOptions) (r FieldName: f.Name, FieldIndex: idx, } - + name := f.Tag.Get(_nameTag) + if len(name) > 0 { + // can modify in-place because options are passed-by-value. + opts.Name = name + } var r result switch { case f.PkgPath != "": @@ -427,20 +467,21 @@ func newResultObjectField(idx int, f reflect.StructField, opts resultOptions) (r if err != nil { return rof, err } + rof.Results = append(rof.Results, r) + if len(name) == 0 { + break + } + fallthrough default: var err error - if name := f.Tag.Get(_nameTag); len(name) > 0 { - // can modify in-place because options are passed-by-value. - opts.Name = name - } - r, err = newResult(f.Type, opts) + r, err = newResult(f.Type, opts, false) if err != nil { return rof, err } + rof.Results = append(rof.Results, r) } - rof.Result = r return rof, nil } @@ -493,7 +534,6 @@ func newResultGrouped(f reflect.StructField) (resultGrouped, error) { Flatten: g.Flatten, Type: f.Type, } - name := f.Tag.Get(_nameTag) optional, _ := isFieldOptional(f) switch { case g.Flatten && f.Type.Kind() != reflect.Slice: @@ -502,9 +542,6 @@ func newResultGrouped(f reflect.StructField) (resultGrouped, error) { case g.Soft: return rg, newErrInvalidInput(fmt.Sprintf( "cannot use soft with result value groups: soft was used with group %q", rg.Group), nil) - case name != "": - return rg, newErrInvalidInput(fmt.Sprintf( - "cannot use named values with value groups: name:%q provided with group:%q", name, rg.Group), nil) case optional: return rg, newErrInvalidInput("value groups cannot be optional", nil) } diff --git a/result_test.go b/result_test.go index c19db20d..db58ce2b 100644 --- a/result_test.go +++ b/result_test.go @@ -108,7 +108,7 @@ func TestNewResultErrors(t *testing.T) { for _, tt := range tests { give := reflect.TypeOf(tt.give) t.Run(fmt.Sprint(give), func(t *testing.T) { - _, err := newResult(give, resultOptions{}) + _, err := newResult(give, resultOptions{}, false) require.Error(t, err) assert.Contains(t, err.Error(), tt.err) }) @@ -139,12 +139,12 @@ func TestNewResultObject(t *testing.T) { { FieldName: "Reader", FieldIndex: 1, - Result: resultSingle{Type: typeOfReader}, + Results: []result{resultSingle{Type: typeOfReader}}, }, { FieldName: "Writer", FieldIndex: 2, - Result: resultSingle{Type: typeOfWriter}, + Results: []result{resultSingle{Type: typeOfWriter}}, }, }, }, @@ -160,12 +160,12 @@ func TestNewResultObject(t *testing.T) { { FieldName: "A", FieldIndex: 1, - Result: resultSingle{Name: "stream-a", Type: typeOfWriter}, + Results: []result{resultSingle{Name: "stream-a", Type: typeOfWriter}}, }, { FieldName: "B", FieldIndex: 2, - Result: resultSingle{Name: "stream-b", Type: typeOfWriter}, + Results: []result{resultSingle{Name: "stream-b", Type: typeOfWriter}}, }, }, }, @@ -180,7 +180,25 @@ func TestNewResultObject(t *testing.T) { { FieldName: "Writer", FieldIndex: 1, - Result: resultGrouped{Group: "writers", Type: typeOfWriter}, + Results: []result{resultGrouped{Group: "writers", Type: typeOfWriter}}, + }, + }, + }, + { + desc: "group and name tag", + give: struct { + Out + + Writer io.Writer `name:"writer1" group:"writers"` + }{}, + wantFields: []resultObjectField{ + { + FieldName: "Writer", + FieldIndex: 1, + Results: []result{ + resultGrouped{Group: "writers", Type: typeOfWriter}, + resultSingle{Name: "writer1", Type: typeOfWriter}, + }, }, }, }, @@ -229,16 +247,6 @@ func TestNewResultObjectErrors(t *testing.T) { }{}, err: `bad field "Nested"`, }, - { - desc: "group with name should fail", - give: struct { - Out - - Foo string `group:"foo" name:"bar"` - }{}, - err: "cannot use named values with value groups: " + - `name:"bar" provided with group:"foo"`, - }, { desc: "group marked as optional", give: struct { @@ -414,31 +422,31 @@ func TestWalkResult(t *testing.T) { { AnnotateWithField: &ro.Fields[0], Return: fakeResultVisits{ - {Visit: ro.Fields[0].Result}, + {Visit: ro.Fields[0].Results[0]}, }, }, { AnnotateWithField: &ro.Fields[1], Return: fakeResultVisits{ - {Visit: ro.Fields[1].Result}, + {Visit: ro.Fields[1].Results[0]}, }, }, { AnnotateWithField: &ro.Fields[2], Return: fakeResultVisits{ { - Visit: ro.Fields[2].Result, + Visit: ro.Fields[2].Results[0], Return: fakeResultVisits{ { - AnnotateWithField: &ro.Fields[2].Result.(resultObject).Fields[0], + AnnotateWithField: &ro.Fields[2].Results[0].(resultObject).Fields[0], Return: fakeResultVisits{ - {Visit: ro.Fields[2].Result.(resultObject).Fields[0].Result}, + {Visit: ro.Fields[2].Results[0].(resultObject).Fields[0].Results[0]}, }, }, { - AnnotateWithField: &ro.Fields[2].Result.(resultObject).Fields[1], + AnnotateWithField: &ro.Fields[2].Results[0].(resultObject).Fields[1], Return: fakeResultVisits{ - {Visit: ro.Fields[2].Result.(resultObject).Fields[1].Result}, + {Visit: ro.Fields[2].Results[0].(resultObject).Fields[1].Results[0]}, }, }, },