| 1 | // Copyright 2022 The Go Authors. All rights reserved. |
|---|---|
| 2 | // Use of this source code is governed by a BSD-style |
| 3 | // license that can be found in the LICENSE file. |
| 4 | package ssa |
| 5 | |
| 6 | import ( |
| 7 | "fmt" |
| 8 | "go/types" |
| 9 | |
| 10 | "golang.org/x/tools/internal/typeparams" |
| 11 | ) |
| 12 | |
| 13 | // Type substituter for a fixed set of replacement types. |
| 14 | // |
| 15 | // A nil *subster is an valid, empty substitution map. It always acts as |
| 16 | // the identity function. This allows for treating parameterized and |
| 17 | // non-parameterized functions identically while compiling to ssa. |
| 18 | // |
| 19 | // Not concurrency-safe. |
| 20 | type subster struct { |
| 21 | // TODO(zpavlinovic): replacements can contain type params |
| 22 | // when generating instances inside of a generic function body. |
| 23 | replacements map[*typeparams.TypeParam]types.Type // values should contain no type params |
| 24 | cache map[types.Type]types.Type // cache of subst results |
| 25 | ctxt *typeparams.Context |
| 26 | debug bool // perform extra debugging checks |
| 27 | // TODO(taking): consider adding Pos |
| 28 | } |
| 29 | |
| 30 | // Returns a subster that replaces tparams[i] with targs[i]. Uses ctxt as a cache. |
| 31 | // targs should not contain any types in tparams. |
| 32 | func makeSubster(ctxt *typeparams.Context, tparams *typeparams.TypeParamList, targs []types.Type, debug bool) *subster { |
| 33 | assert(tparams.Len() == len(targs), "makeSubster argument count must match") |
| 34 | |
| 35 | subst := &subster{ |
| 36 | replacements: make(map[*typeparams.TypeParam]types.Type, tparams.Len()), |
| 37 | cache: make(map[types.Type]types.Type), |
| 38 | ctxt: ctxt, |
| 39 | debug: debug, |
| 40 | } |
| 41 | for i := 0; i < tparams.Len(); i++ { |
| 42 | subst.replacements[tparams.At(i)] = targs[i] |
| 43 | } |
| 44 | if subst.debug { |
| 45 | if err := subst.wellFormed(); err != nil { |
| 46 | panic(err) |
| 47 | } |
| 48 | } |
| 49 | return subst |
| 50 | } |
| 51 | |
| 52 | // wellFormed returns an error if subst was not properly initialized. |
| 53 | func (subst *subster) wellFormed() error { |
| 54 | if subst == nil || len(subst.replacements) == 0 { |
| 55 | return nil |
| 56 | } |
| 57 | // Check that all of the type params do not appear in the arguments. |
| 58 | s := make(map[types.Type]bool, len(subst.replacements)) |
| 59 | for tparam := range subst.replacements { |
| 60 | s[tparam] = true |
| 61 | } |
| 62 | for _, r := range subst.replacements { |
| 63 | if reaches(r, s) { |
| 64 | return fmt.Errorf("\n‰r %s s %v replacements %v\n", r, s, subst.replacements) |
| 65 | } |
| 66 | } |
| 67 | return nil |
| 68 | } |
| 69 | |
| 70 | // typ returns the type of t with the type parameter tparams[i] substituted |
| 71 | // for the type targs[i] where subst was created using tparams and targs. |
| 72 | func (subst *subster) typ(t types.Type) (res types.Type) { |
| 73 | if subst == nil { |
| 74 | return t // A nil subst is type preserving. |
| 75 | } |
| 76 | if r, ok := subst.cache[t]; ok { |
| 77 | return r |
| 78 | } |
| 79 | defer func() { |
| 80 | subst.cache[t] = res |
| 81 | }() |
| 82 | |
| 83 | // fall through if result r will be identical to t, types.Identical(r, t). |
| 84 | switch t := t.(type) { |
| 85 | case *typeparams.TypeParam: |
| 86 | r := subst.replacements[t] |
| 87 | assert(r != nil, "type param without replacement encountered") |
| 88 | return r |
| 89 | |
| 90 | case *types.Basic: |
| 91 | return t |
| 92 | |
| 93 | case *types.Array: |
| 94 | if r := subst.typ(t.Elem()); r != t.Elem() { |
| 95 | return types.NewArray(r, t.Len()) |
| 96 | } |
| 97 | return t |
| 98 | |
| 99 | case *types.Slice: |
| 100 | if r := subst.typ(t.Elem()); r != t.Elem() { |
| 101 | return types.NewSlice(r) |
| 102 | } |
| 103 | return t |
| 104 | |
| 105 | case *types.Pointer: |
| 106 | if r := subst.typ(t.Elem()); r != t.Elem() { |
| 107 | return types.NewPointer(r) |
| 108 | } |
| 109 | return t |
| 110 | |
| 111 | case *types.Tuple: |
| 112 | return subst.tuple(t) |
| 113 | |
| 114 | case *types.Struct: |
| 115 | return subst.struct_(t) |
| 116 | |
| 117 | case *types.Map: |
| 118 | key := subst.typ(t.Key()) |
| 119 | elem := subst.typ(t.Elem()) |
| 120 | if key != t.Key() || elem != t.Elem() { |
| 121 | return types.NewMap(key, elem) |
| 122 | } |
| 123 | return t |
| 124 | |
| 125 | case *types.Chan: |
| 126 | if elem := subst.typ(t.Elem()); elem != t.Elem() { |
| 127 | return types.NewChan(t.Dir(), elem) |
| 128 | } |
| 129 | return t |
| 130 | |
| 131 | case *types.Signature: |
| 132 | return subst.signature(t) |
| 133 | |
| 134 | case *typeparams.Union: |
| 135 | return subst.union(t) |
| 136 | |
| 137 | case *types.Interface: |
| 138 | return subst.interface_(t) |
| 139 | |
| 140 | case *types.Named: |
| 141 | return subst.named(t) |
| 142 | |
| 143 | default: |
| 144 | panic("unreachable") |
| 145 | } |
| 146 | } |
| 147 | |
| 148 | // types returns the result of {subst.typ(ts[i])}. |
| 149 | func (subst *subster) types(ts []types.Type) []types.Type { |
| 150 | res := make([]types.Type, len(ts)) |
| 151 | for i := range ts { |
| 152 | res[i] = subst.typ(ts[i]) |
| 153 | } |
| 154 | return res |
| 155 | } |
| 156 | |
| 157 | func (subst *subster) tuple(t *types.Tuple) *types.Tuple { |
| 158 | if t != nil { |
| 159 | if vars := subst.varlist(t); vars != nil { |
| 160 | return types.NewTuple(vars...) |
| 161 | } |
| 162 | } |
| 163 | return t |
| 164 | } |
| 165 | |
// varlist abstracts over an indexed sequence of *types.Var, such as a
// *types.Tuple or the fields of a struct (see fieldlist).
type varlist interface {
	// Len returns the number of variables in the list.
	Len() int
	// At returns the i'th variable, for 0 <= i < Len().
	At(i int) *types.Var
}
| 170 | |
// fieldlist adapts a *types.Struct to the varlist interface, presenting the
// struct's fields as the list's variables.
type fieldlist struct {
	str *types.Struct
}

// At returns the i'th field of the wrapped struct.
func (fl fieldlist) At(i int) *types.Var {
	return fl.str.Field(i)
}

// Len returns the number of fields of the wrapped struct.
func (fl fieldlist) Len() int {
	return fl.str.NumFields()
}
| 178 | |
| 179 | func (subst *subster) struct_(t *types.Struct) *types.Struct { |
| 180 | if t != nil { |
| 181 | if fields := subst.varlist(fieldlist{t}); fields != nil { |
| 182 | tags := make([]string, t.NumFields()) |
| 183 | for i, n := 0, t.NumFields(); i < n; i++ { |
| 184 | tags[i] = t.Tag(i) |
| 185 | } |
| 186 | return types.NewStruct(fields, tags) |
| 187 | } |
| 188 | } |
| 189 | return t |
| 190 | } |
| 191 | |
| 192 | // varlist reutrns subst(in[i]) or return nils if subst(v[i]) == v[i] for all i. |
| 193 | func (subst *subster) varlist(in varlist) []*types.Var { |
| 194 | var out []*types.Var // nil => no updates |
| 195 | for i, n := 0, in.Len(); i < n; i++ { |
| 196 | v := in.At(i) |
| 197 | w := subst.var_(v) |
| 198 | if v != w && out == nil { |
| 199 | out = make([]*types.Var, n) |
| 200 | for j := 0; j < i; j++ { |
| 201 | out[j] = in.At(j) |
| 202 | } |
| 203 | } |
| 204 | if out != nil { |
| 205 | out[i] = w |
| 206 | } |
| 207 | } |
| 208 | return out |
| 209 | } |
| 210 | |
| 211 | func (subst *subster) var_(v *types.Var) *types.Var { |
| 212 | if v != nil { |
| 213 | if typ := subst.typ(v.Type()); typ != v.Type() { |
| 214 | if v.IsField() { |
| 215 | return types.NewField(v.Pos(), v.Pkg(), v.Name(), typ, v.Embedded()) |
| 216 | } |
| 217 | return types.NewVar(v.Pos(), v.Pkg(), v.Name(), typ) |
| 218 | } |
| 219 | } |
| 220 | return v |
| 221 | } |
| 222 | |
| 223 | func (subst *subster) union(u *typeparams.Union) *typeparams.Union { |
| 224 | var out []*typeparams.Term // nil => no updates |
| 225 | |
| 226 | for i, n := 0, u.Len(); i < n; i++ { |
| 227 | t := u.Term(i) |
| 228 | r := subst.typ(t.Type()) |
| 229 | if r != t.Type() && out == nil { |
| 230 | out = make([]*typeparams.Term, n) |
| 231 | for j := 0; j < i; j++ { |
| 232 | out[j] = u.Term(j) |
| 233 | } |
| 234 | } |
| 235 | if out != nil { |
| 236 | out[i] = typeparams.NewTerm(t.Tilde(), r) |
| 237 | } |
| 238 | } |
| 239 | |
| 240 | if out != nil { |
| 241 | return typeparams.NewUnion(out) |
| 242 | } |
| 243 | return u |
| 244 | } |
| 245 | |
| 246 | func (subst *subster) interface_(iface *types.Interface) *types.Interface { |
| 247 | if iface == nil { |
| 248 | return nil |
| 249 | } |
| 250 | |
| 251 | // methods for the interface. Initially nil if there is no known change needed. |
| 252 | // Signatures for the method where recv is nil. NewInterfaceType fills in the recievers. |
| 253 | var methods []*types.Func |
| 254 | initMethods := func(n int) { // copy first n explicit methods |
| 255 | methods = make([]*types.Func, iface.NumExplicitMethods()) |
| 256 | for i := 0; i < n; i++ { |
| 257 | f := iface.ExplicitMethod(i) |
| 258 | norecv := changeRecv(f.Type().(*types.Signature), nil) |
| 259 | methods[i] = types.NewFunc(f.Pos(), f.Pkg(), f.Name(), norecv) |
| 260 | } |
| 261 | } |
| 262 | for i := 0; i < iface.NumExplicitMethods(); i++ { |
| 263 | f := iface.ExplicitMethod(i) |
| 264 | // On interfaces, we need to cycle break on anonymous interface types |
| 265 | // being in a cycle with their signatures being in cycles with their recievers |
| 266 | // that do not go through a Named. |
| 267 | norecv := changeRecv(f.Type().(*types.Signature), nil) |
| 268 | sig := subst.typ(norecv) |
| 269 | if sig != norecv && methods == nil { |
| 270 | initMethods(i) |
| 271 | } |
| 272 | if methods != nil { |
| 273 | methods[i] = types.NewFunc(f.Pos(), f.Pkg(), f.Name(), sig.(*types.Signature)) |
| 274 | } |
| 275 | } |
| 276 | |
| 277 | var embeds []types.Type |
| 278 | initEmbeds := func(n int) { // copy first n embedded types |
| 279 | embeds = make([]types.Type, iface.NumEmbeddeds()) |
| 280 | for i := 0; i < n; i++ { |
| 281 | embeds[i] = iface.EmbeddedType(i) |
| 282 | } |
| 283 | } |
| 284 | for i := 0; i < iface.NumEmbeddeds(); i++ { |
| 285 | e := iface.EmbeddedType(i) |
| 286 | r := subst.typ(e) |
| 287 | if e != r && embeds == nil { |
| 288 | initEmbeds(i) |
| 289 | } |
| 290 | if embeds != nil { |
| 291 | embeds[i] = r |
| 292 | } |
| 293 | } |
| 294 | |
| 295 | if methods == nil && embeds == nil { |
| 296 | return iface |
| 297 | } |
| 298 | if methods == nil { |
| 299 | initMethods(iface.NumExplicitMethods()) |
| 300 | } |
| 301 | if embeds == nil { |
| 302 | initEmbeds(iface.NumEmbeddeds()) |
| 303 | } |
| 304 | return types.NewInterfaceType(methods, embeds).Complete() |
| 305 | } |
| 306 | |
| 307 | func (subst *subster) named(t *types.Named) types.Type { |
| 308 | // A name type may be: |
| 309 | // (1) ordinary (no type parameters, no type arguments), |
| 310 | // (2) generic (type parameters but no type arguments), or |
| 311 | // (3) instantiated (type parameters and type arguments). |
| 312 | tparams := typeparams.ForNamed(t) |
| 313 | if tparams.Len() == 0 { |
| 314 | // case (1) ordinary |
| 315 | |
| 316 | // Note: If Go allows for local type declarations in generic |
| 317 | // functions we may need to descend into underlying as well. |
| 318 | return t |
| 319 | } |
| 320 | targs := typeparams.NamedTypeArgs(t) |
| 321 | |
| 322 | // insts are arguments to instantiate using. |
| 323 | insts := make([]types.Type, tparams.Len()) |
| 324 | |
| 325 | // case (2) generic ==> targs.Len() == 0 |
| 326 | // Instantiating a generic with no type arguments should be unreachable. |
| 327 | // Please report a bug if you encounter this. |
| 328 | assert(targs.Len() != 0, "substition into a generic Named type is currently unsupported") |
| 329 | |
| 330 | // case (3) instantiated. |
| 331 | // Substitute into the type arguments and instantiate the replacements/ |
| 332 | // Example: |
| 333 | // type N[A any] func() A |
| 334 | // func Foo[T](g N[T]) {} |
| 335 | // To instantiate Foo[string], one goes through {T->string}. To get the type of g |
| 336 | // one subsitutes T with string in {N with typeargs == {T} and typeparams == {A} } |
| 337 | // to get {N with TypeArgs == {string} and typeparams == {A} }. |
| 338 | assert(targs.Len() == tparams.Len(), "typeargs.Len() must match typeparams.Len() if present") |
| 339 | for i, n := 0, targs.Len(); i < n; i++ { |
| 340 | inst := subst.typ(targs.At(i)) // TODO(generic): Check with rfindley for mutual recursion |
| 341 | insts[i] = inst |
| 342 | } |
| 343 | r, err := typeparams.Instantiate(subst.ctxt, typeparams.NamedTypeOrigin(t), insts, false) |
| 344 | assert(err == nil, "failed to Instantiate Named type") |
| 345 | return r |
| 346 | } |
| 347 | |
| 348 | func (subst *subster) signature(t *types.Signature) types.Type { |
| 349 | tparams := typeparams.ForSignature(t) |
| 350 | |
| 351 | // We are choosing not to support tparams.Len() > 0 until a need has been observed in practice. |
| 352 | // |
| 353 | // There are some known usages for types.Types coming from types.{Eval,CheckExpr}. |
| 354 | // To support tparams.Len() > 0, we just need to do the following [psuedocode]: |
| 355 | // targs := {subst.replacements[tparams[i]]]}; Instantiate(ctxt, t, targs, false) |
| 356 | |
| 357 | assert(tparams.Len() == 0, "Substituting types.Signatures with generic functions are currently unsupported.") |
| 358 | |
| 359 | // Either: |
| 360 | // (1)non-generic function. |
| 361 | // no type params to substitute |
| 362 | // (2)generic method and recv needs to be substituted. |
| 363 | |
| 364 | // Recievers can be either: |
| 365 | // named |
| 366 | // pointer to named |
| 367 | // interface |
| 368 | // nil |
| 369 | // interface is the problematic case. We need to cycle break there! |
| 370 | recv := subst.var_(t.Recv()) |
| 371 | params := subst.tuple(t.Params()) |
| 372 | results := subst.tuple(t.Results()) |
| 373 | if recv != t.Recv() || params != t.Params() || results != t.Results() { |
| 374 | return typeparams.NewSignatureType(recv, nil, nil, params, results, t.Variadic()) |
| 375 | } |
| 376 | return t |
| 377 | } |
| 378 | |
| 379 | // reaches returns true if a type t reaches any type t' s.t. c[t'] == true. |
| 380 | // Updates c to cache results. |
| 381 | func reaches(t types.Type, c map[types.Type]bool) (res bool) { |
| 382 | if c, ok := c[t]; ok { |
| 383 | return c |
| 384 | } |
| 385 | c[t] = false // prevent cycles |
| 386 | defer func() { |
| 387 | c[t] = res |
| 388 | }() |
| 389 | |
| 390 | switch t := t.(type) { |
| 391 | case *typeparams.TypeParam, *types.Basic: |
| 392 | // no-op => c == false |
| 393 | case *types.Array: |
| 394 | return reaches(t.Elem(), c) |
| 395 | case *types.Slice: |
| 396 | return reaches(t.Elem(), c) |
| 397 | case *types.Pointer: |
| 398 | return reaches(t.Elem(), c) |
| 399 | case *types.Tuple: |
| 400 | for i := 0; i < t.Len(); i++ { |
| 401 | if reaches(t.At(i).Type(), c) { |
| 402 | return true |
| 403 | } |
| 404 | } |
| 405 | case *types.Struct: |
| 406 | for i := 0; i < t.NumFields(); i++ { |
| 407 | if reaches(t.Field(i).Type(), c) { |
| 408 | return true |
| 409 | } |
| 410 | } |
| 411 | case *types.Map: |
| 412 | return reaches(t.Key(), c) || reaches(t.Elem(), c) |
| 413 | case *types.Chan: |
| 414 | return reaches(t.Elem(), c) |
| 415 | case *types.Signature: |
| 416 | if t.Recv() != nil && reaches(t.Recv().Type(), c) { |
| 417 | return true |
| 418 | } |
| 419 | return reaches(t.Params(), c) || reaches(t.Results(), c) |
| 420 | case *typeparams.Union: |
| 421 | for i := 0; i < t.Len(); i++ { |
| 422 | if reaches(t.Term(i).Type(), c) { |
| 423 | return true |
| 424 | } |
| 425 | } |
| 426 | case *types.Interface: |
| 427 | for i := 0; i < t.NumEmbeddeds(); i++ { |
| 428 | if reaches(t.Embedded(i), c) { |
| 429 | return true |
| 430 | } |
| 431 | } |
| 432 | for i := 0; i < t.NumExplicitMethods(); i++ { |
| 433 | if reaches(t.ExplicitMethod(i).Type(), c) { |
| 434 | return true |
| 435 | } |
| 436 | } |
| 437 | case *types.Named: |
| 438 | return reaches(t.Underlying(), c) |
| 439 | default: |
| 440 | panic("unreachable") |
| 441 | } |
| 442 | return false |
| 443 | } |
| 444 |
Members