| 1 | // Copyright 2021 The Go Authors. All rights reserved. |
|---|---|
| 2 | // Use of this source code is governed by a BSD-style |
| 3 | // license that can be found in the LICENSE file. |
| 4 | |
| 5 | package vta |
| 6 | |
| 7 | import ( |
| 8 | "go/types" |
| 9 | |
| 10 | "golang.org/x/tools/go/callgraph" |
| 11 | "golang.org/x/tools/go/ssa" |
| 12 | "golang.org/x/tools/internal/typeparams" |
| 13 | ) |
| 14 | |
| 15 | func canAlias(n1, n2 node) bool { |
| 16 | return isReferenceNode(n1) && isReferenceNode(n2) |
| 17 | } |
| 18 | |
| 19 | func isReferenceNode(n node) bool { |
| 20 | if _, ok := n.(nestedPtrInterface); ok { |
| 21 | return true |
| 22 | } |
| 23 | if _, ok := n.(nestedPtrFunction); ok { |
| 24 | return true |
| 25 | } |
| 26 | |
| 27 | if _, ok := n.Type().(*types.Pointer); ok { |
| 28 | return true |
| 29 | } |
| 30 | |
| 31 | return false |
| 32 | } |
| 33 | |
| 34 | // hasInFlow checks if a concrete type can flow to node `n`. |
| 35 | // Returns yes iff the type of `n` satisfies one the following: |
| 36 | // 1. is an interface |
| 37 | // 2. is a (nested) pointer to interface (needed for, say, |
| 38 | // slice elements of nested pointers to interface type) |
| 39 | // 3. is a function type (needed for higher-order type flow) |
| 40 | // 4. is a (nested) pointer to function (needed for, say, |
| 41 | // slice elements of nested pointers to function type) |
| 42 | // 5. is a global Recover or Panic node |
| 43 | func hasInFlow(n node) bool { |
| 44 | if _, ok := n.(panicArg); ok { |
| 45 | return true |
| 46 | } |
| 47 | if _, ok := n.(recoverReturn); ok { |
| 48 | return true |
| 49 | } |
| 50 | |
| 51 | t := n.Type() |
| 52 | |
| 53 | if i := interfaceUnderPtr(t); i != nil { |
| 54 | return true |
| 55 | } |
| 56 | if f := functionUnderPtr(t); f != nil { |
| 57 | return true |
| 58 | } |
| 59 | |
| 60 | return types.IsInterface(t) || isFunction(t) |
| 61 | } |
| 62 | |
// isFunction reports whether the underlying type of `t` is a
// function signature.
func isFunction(t types.Type) bool {
	switch t.Underlying().(type) {
	case *types.Signature:
		return true
	default:
		return false
	}
}
| 67 | |
// interfaceUnderPtr peels pointers off type `t` one level at a time.
// If an interface type is found beneath one or more pointers, that
// interface type is returned. In all other cases, including when `t`
// is not a pointer at all, the result is nil.
func interfaceUnderPtr(t types.Type) types.Type {
	// visited stops the walk when a type is reached a second time,
	// which guards against cyclic pointer types.
	visited := make(map[types.Type]bool)

	for cur := t; !visited[cur]; {
		visited[cur] = true

		ptr, isPtr := cur.Underlying().(*types.Pointer)
		if !isPtr {
			return nil
		}

		elem := ptr.Elem()
		if types.IsInterface(elem) {
			return elem
		}
		cur = elem
	}
	return nil
}
| 93 | |
| 94 | // functionUnderPtr checks if type `t` is a potentially nested |
| 95 | // pointer to function type and if yes, returns the function type. |
| 96 | // Otherwise, returns nil. |
| 97 | func functionUnderPtr(t types.Type) types.Type { |
| 98 | seen := make(map[types.Type]bool) |
| 99 | var visit func(types.Type) types.Type |
| 100 | visit = func(t types.Type) types.Type { |
| 101 | if seen[t] { |
| 102 | return nil |
| 103 | } |
| 104 | seen[t] = true |
| 105 | |
| 106 | p, ok := t.Underlying().(*types.Pointer) |
| 107 | if !ok { |
| 108 | return nil |
| 109 | } |
| 110 | |
| 111 | if isFunction(p.Elem()) { |
| 112 | return p.Elem() |
| 113 | } |
| 114 | |
| 115 | return visit(p.Elem()) |
| 116 | } |
| 117 | return visit(t) |
| 118 | } |
| 119 | |
| 120 | // sliceArrayElem returns the element type of type `t` that is |
| 121 | // expected to be a (pointer to) array, slice or string, consistent with |
| 122 | // the ssa.Index and ssa.IndexAddr instructions. Panics otherwise. |
| 123 | func sliceArrayElem(t types.Type) types.Type { |
| 124 | switch u := t.Underlying().(type) { |
| 125 | case *types.Pointer: |
| 126 | return u.Elem().Underlying().(*types.Array).Elem() |
| 127 | case *types.Array: |
| 128 | return u.Elem() |
| 129 | case *types.Slice: |
| 130 | return u.Elem() |
| 131 | case *types.Basic: |
| 132 | return types.Typ[types.Byte] |
| 133 | case *types.Interface: // type param. |
| 134 | terms, err := typeparams.InterfaceTermSet(u) |
| 135 | if err != nil || len(terms) == 0 { |
| 136 | panic(t) |
| 137 | } |
| 138 | return sliceArrayElem(terms[0].Type()) // Element types must match. |
| 139 | default: |
| 140 | panic(t) |
| 141 | } |
| 142 | } |
| 143 | |
| 144 | // siteCallees computes a set of callees for call site `c` given program `callgraph`. |
| 145 | func siteCallees(c ssa.CallInstruction, callgraph *callgraph.Graph) []*ssa.Function { |
| 146 | var matches []*ssa.Function |
| 147 | |
| 148 | node := callgraph.Nodes[c.Parent()] |
| 149 | if node == nil { |
| 150 | return nil |
| 151 | } |
| 152 | |
| 153 | for _, edge := range node.Out { |
| 154 | if edge.Site == c { |
| 155 | matches = append(matches, edge.Callee.Func) |
| 156 | } |
| 157 | } |
| 158 | return matches |
| 159 | } |
| 160 | |
// canHaveMethods reports whether `t` is a named type or a type whose
// underlying type is an interface, a function signature or a struct.
func canHaveMethods(t types.Type) bool {
	switch t.(type) {
	case *types.Named:
		return true
	}

	switch t.Underlying().(type) {
	case *types.Interface, *types.Signature, *types.Struct:
		return true
	}
	return false
}
| 174 | |
| 175 | // calls returns the set of call instructions in `f`. |
| 176 | func calls(f *ssa.Function) []ssa.CallInstruction { |
| 177 | var calls []ssa.CallInstruction |
| 178 | for _, bl := range f.Blocks { |
| 179 | for _, instr := range bl.Instrs { |
| 180 | if c, ok := instr.(ssa.CallInstruction); ok { |
| 181 | calls = append(calls, c) |
| 182 | } |
| 183 | } |
| 184 | } |
| 185 | return calls |
| 186 | } |
| 187 | |
| 188 | // intersect produces an intersection of functions in `fs1` and `fs2`. |
| 189 | func intersect(fs1, fs2 []*ssa.Function) []*ssa.Function { |
| 190 | m := make(map[*ssa.Function]bool) |
| 191 | for _, f := range fs1 { |
| 192 | m[f] = true |
| 193 | } |
| 194 | |
| 195 | var res []*ssa.Function |
| 196 | for _, f := range fs2 { |
| 197 | if m[f] { |
| 198 | res = append(res, f) |
| 199 | } |
| 200 | } |
| 201 | return res |
| 202 | } |
| 203 |
Members