| 1 | // Copyright 2018 The Go Authors. All rights reserved. |
|---|---|
| 2 | // Use of this source code is governed by a BSD-style |
| 3 | // license that can be found in the LICENSE file. |
| 4 | |
| 5 | package jsonrpc2_test |
| 6 | |
| 7 | import ( |
| 8 | "context" |
| 9 | "encoding/json" |
| 10 | "flag" |
| 11 | "fmt" |
| 12 | "net" |
| 13 | "path" |
| 14 | "reflect" |
| 15 | "testing" |
| 16 | |
| 17 | "golang.org/x/tools/internal/event/export/eventtest" |
| 18 | "golang.org/x/tools/internal/jsonrpc2" |
| 19 | "golang.org/x/tools/internal/stack/stacktest" |
| 20 | ) |
| 21 | |
// logRPC holds the value of the -logrpc test flag (default false). run
// dereferences it and passes the result to testHandler.
var (
	logRPC = flag.Bool("logrpc", false, "Enable jsonrpc2 communication logging")
)
| 23 | |
// callTest describes one round trip exercised by TestCall: the method to
// invoke, the params sent with it, and the result the handler should return.
// When expect is nil, newResults allocates no result and verifyResults skips
// the comparison.
type callTest struct {
	method         string      // jsonrpc2 method name
	params, expect interface{} // request params and expected decoded result
}
| 29 | |
| 30 | var callTests = []callTest{ |
| 31 | {"no_args", nil, true}, |
| 32 | {"one_string", "fish", "got:fish"}, |
| 33 | {"one_number", 10, "got:10"}, |
| 34 | {"join", []string{"a", "b", "c"}, "a/b/c"}, |
| 35 | //TODO: expand the test cases |
| 36 | } |
| 37 | |
| 38 | func (test *callTest) newResults() interface{} { |
| 39 | switch e := test.expect.(type) { |
| 40 | case []interface{}: |
| 41 | var r []interface{} |
| 42 | for _, v := range e { |
| 43 | r = append(r, reflect.New(reflect.TypeOf(v)).Interface()) |
| 44 | } |
| 45 | return r |
| 46 | case nil: |
| 47 | return nil |
| 48 | default: |
| 49 | return reflect.New(reflect.TypeOf(test.expect)).Interface() |
| 50 | } |
| 51 | } |
| 52 | |
| 53 | func (test *callTest) verifyResults(t *testing.T, results interface{}) { |
| 54 | if results == nil { |
| 55 | return |
| 56 | } |
| 57 | val := reflect.Indirect(reflect.ValueOf(results)).Interface() |
| 58 | if !reflect.DeepEqual(val, test.expect) { |
| 59 | t.Errorf("%v:Results are incorrect, got %+v expect %+v", test.method, val, test.expect) |
| 60 | } |
| 61 | } |
| 62 | |
| 63 | func TestCall(t *testing.T) { |
| 64 | stacktest.NoLeak(t) |
| 65 | ctx := eventtest.NewContext(context.Background(), t) |
| 66 | for _, headers := range []bool{false, true} { |
| 67 | name := "Plain" |
| 68 | if headers { |
| 69 | name = "Headers" |
| 70 | } |
| 71 | t.Run(name, func(t *testing.T) { |
| 72 | ctx := eventtest.NewContext(ctx, t) |
| 73 | a, b, done := prepare(ctx, t, headers) |
| 74 | defer done() |
| 75 | for _, test := range callTests { |
| 76 | t.Run(test.method, func(t *testing.T) { |
| 77 | ctx := eventtest.NewContext(ctx, t) |
| 78 | results := test.newResults() |
| 79 | if _, err := a.Call(ctx, test.method, test.params, results); err != nil { |
| 80 | t.Fatalf("%v:Call failed: %v", test.method, err) |
| 81 | } |
| 82 | test.verifyResults(t, results) |
| 83 | if _, err := b.Call(ctx, test.method, test.params, results); err != nil { |
| 84 | t.Fatalf("%v:Call failed: %v", test.method, err) |
| 85 | } |
| 86 | test.verifyResults(t, results) |
| 87 | }) |
| 88 | } |
| 89 | }) |
| 90 | } |
| 91 | } |
| 92 | |
| 93 | func prepare(ctx context.Context, t *testing.T, withHeaders bool) (jsonrpc2.Conn, jsonrpc2.Conn, func()) { |
| 94 | // make a wait group that can be used to wait for the system to shut down |
| 95 | aPipe, bPipe := net.Pipe() |
| 96 | a := run(ctx, withHeaders, aPipe) |
| 97 | b := run(ctx, withHeaders, bPipe) |
| 98 | return a, b, func() { |
| 99 | a.Close() |
| 100 | b.Close() |
| 101 | <-a.Done() |
| 102 | <-b.Done() |
| 103 | } |
| 104 | } |
| 105 | |
| 106 | func run(ctx context.Context, withHeaders bool, nc net.Conn) jsonrpc2.Conn { |
| 107 | var stream jsonrpc2.Stream |
| 108 | if withHeaders { |
| 109 | stream = jsonrpc2.NewHeaderStream(nc) |
| 110 | } else { |
| 111 | stream = jsonrpc2.NewRawStream(nc) |
| 112 | } |
| 113 | conn := jsonrpc2.NewConn(stream) |
| 114 | conn.Go(ctx, testHandler(*logRPC)) |
| 115 | return conn |
| 116 | } |
| 117 | |
| 118 | func testHandler(log bool) jsonrpc2.Handler { |
| 119 | return func(ctx context.Context, reply jsonrpc2.Replier, req jsonrpc2.Request) error { |
| 120 | switch req.Method() { |
| 121 | case "no_args": |
| 122 | if len(req.Params()) > 0 { |
| 123 | return reply(ctx, nil, fmt.Errorf("%w: expected no params", jsonrpc2.ErrInvalidParams)) |
| 124 | } |
| 125 | return reply(ctx, true, nil) |
| 126 | case "one_string": |
| 127 | var v string |
| 128 | if err := json.Unmarshal(req.Params(), &v); err != nil { |
| 129 | return reply(ctx, nil, fmt.Errorf("%w: %s", jsonrpc2.ErrParse, err)) |
| 130 | } |
| 131 | return reply(ctx, "got:"+v, nil) |
| 132 | case "one_number": |
| 133 | var v int |
| 134 | if err := json.Unmarshal(req.Params(), &v); err != nil { |
| 135 | return reply(ctx, nil, fmt.Errorf("%w: %s", jsonrpc2.ErrParse, err)) |
| 136 | } |
| 137 | return reply(ctx, fmt.Sprintf("got:%d", v), nil) |
| 138 | case "join": |
| 139 | var v []string |
| 140 | if err := json.Unmarshal(req.Params(), &v); err != nil { |
| 141 | return reply(ctx, nil, fmt.Errorf("%w: %s", jsonrpc2.ErrParse, err)) |
| 142 | } |
| 143 | return reply(ctx, path.Join(v...), nil) |
| 144 | default: |
| 145 | return jsonrpc2.MethodNotFound(ctx, reply, req) |
| 146 | } |
| 147 | } |
| 148 | } |
| 149 |
Members