| 1 | // Copyright 2014 The Go Authors. All rights reserved. |
|---|---|
| 2 | // Use of this source code is governed by a BSD-style |
| 3 | // license that can be found in the LICENSE file. |
| 4 | |
| 5 | package eg |
| 6 | |
| 7 | // This file defines the AST rewriting pass. |
| 8 | // Most of it was plundered directly from |
| 9 | // $GOROOT/src/cmd/gofmt/rewrite.go (after convergent evolution). |
| 10 | |
| 11 | import ( |
| 12 | "fmt" |
| 13 | "go/ast" |
| 14 | "go/token" |
| 15 | "go/types" |
| 16 | "os" |
| 17 | "reflect" |
| 18 | "sort" |
| 19 | "strconv" |
| 20 | "strings" |
| 21 | |
| 22 | "golang.org/x/tools/go/ast/astutil" |
| 23 | ) |
| 24 | |
// transformItem is the callback that Transform threads through apply.
// Given rv, a reflect.Value for an AST node (or a field/element holding
// one), it first rewrites rv's descendants bottom-up via apply. If the
// resulting value is an ast.Expr that matches tr.before, it then
// substitutes a fresh copy of tr.after for it, with wildcard bindings
// filled in.
//
// It returns:
//   - the possibly-replaced value;
//   - whether a rewrite happened at rv or below;
//   - the wildcard bindings to report to the enclosing statement list.
//     A match at rv itself takes precedence over bindings from its
//     descendants.
func (tr *Transformer) transformItem(rv reflect.Value) (reflect.Value, bool, map[string]ast.Expr) {
	if !rv.IsValid() {
		return reflect.Value{}, false, nil // nothing to transform
	}

	// Bottom-up: rewrite the children first.
	out, changed, bindings := tr.apply(tr.transformItem, rv)

	expr := rvToExpr(out)
	if expr == nil {
		return out, changed, bindings // not an expression, so it cannot match
	}

	// Matching records wildcard bindings in tr.env. Give it a fresh map,
	// and restore the caller's map afterwards.
	outerEnv := tr.env
	tr.env = make(map[string]ast.Expr) // inefficient! Use a slice of k/v pairs

	if tr.matchExpr(tr.before, expr) {
		if tr.verbose {
			fmt.Fprintf(os.Stderr, "%s matches %s",
				astString(tr.fset, tr.before), astString(tr.fset, expr))
			if len(tr.env) > 0 {
				fmt.Fprintf(os.Stderr, " with:")
				for name, bound := range tr.env {
					fmt.Fprintf(os.Stderr, " %s->%s", name, astString(tr.fset, bound))
				}
			}
			fmt.Fprintf(os.Stderr, "\n")
		}
		tr.nsubsts++

		// Clone the replacement, substituting the bindings. Every valid
		// position in the clone is set to expr.Pos() to help comment
		// placement.
		out = tr.subst(tr.env, reflect.ValueOf(tr.after), reflect.ValueOf(expr.Pos()))
		changed, bindings = true, tr.env
	}
	tr.env = outerEnv

	return out, changed, bindings
}
| 70 | |
| 71 | // Transform applies the transformation to the specified parsed file, |
| 72 | // whose type information is supplied in info, and returns the number |
| 73 | // of replacements that were made. |
| 74 | // |
| 75 | // It mutates the AST in place (the identity of the root node is |
| 76 | // unchanged), and may add nodes for which no type information is |
| 77 | // available in info. |
| 78 | // |
| 79 | // Derived from rewriteFile in $GOROOT/src/cmd/gofmt/rewrite.go. |
| 80 | func (tr *Transformer) Transform(info *types.Info, pkg *types.Package, file *ast.File) int { |
| 81 | if !tr.seenInfos[info] { |
| 82 | tr.seenInfos[info] = true |
| 83 | mergeTypeInfo(tr.info, info) |
| 84 | } |
| 85 | tr.currentPkg = pkg |
| 86 | tr.nsubsts = 0 |
| 87 | |
| 88 | if tr.verbose { |
| 89 | fmt.Fprintf(os.Stderr, "before: %s\n", astString(tr.fset, tr.before)) |
| 90 | fmt.Fprintf(os.Stderr, "after: %s\n", astString(tr.fset, tr.after)) |
| 91 | fmt.Fprintf(os.Stderr, "afterStmts: %s\n", tr.afterStmts) |
| 92 | } |
| 93 | |
| 94 | o, changed, _ := tr.apply(tr.transformItem, reflect.ValueOf(file)) |
| 95 | if changed { |
| 96 | panic("BUG") |
| 97 | } |
| 98 | file2 := o.Interface().(*ast.File) |
| 99 | |
| 100 | // By construction, the root node is unchanged. |
| 101 | if file != file2 { |
| 102 | panic("BUG") |
| 103 | } |
| 104 | |
| 105 | // Add any necessary imports. |
| 106 | // TODO(adonovan): remove no-longer needed imports too. |
| 107 | if tr.nsubsts > 0 { |
| 108 | pkgs := make(map[string]*types.Package) |
| 109 | for obj := range tr.importedObjs { |
| 110 | pkgs[obj.Pkg().Path()] = obj.Pkg() |
| 111 | } |
| 112 | |
| 113 | for _, imp := range file.Imports { |
| 114 | path, _ := strconv.Unquote(imp.Path.Value) |
| 115 | delete(pkgs, path) |
| 116 | } |
| 117 | delete(pkgs, pkg.Path()) // don't import self |
| 118 | |
| 119 | // NB: AddImport may completely replace the AST! |
| 120 | // It thus renders info and tr.info no longer relevant to file. |
| 121 | var paths []string |
| 122 | for path := range pkgs { |
| 123 | paths = append(paths, path) |
| 124 | } |
| 125 | sort.Strings(paths) |
| 126 | for _, path := range paths { |
| 127 | astutil.AddImport(tr.fset, file, path) |
| 128 | } |
| 129 | } |
| 130 | |
| 131 | tr.currentPkg = nil |
| 132 | |
| 133 | return tr.nsubsts |
| 134 | } |
| 135 | |
// setValue performs x.Set(y), tolerating the case where y cannot be stored
// in x.
//
// An invalid y is a no-op. If reflect rejects the assignment because of a
// type mismatch or non-assignability, the panic is swallowed and x is left
// unchanged, which effectively skips that rewrite. Any other panic (for
// example, x not being settable) is re-raised unchanged.
func setValue(x, y reflect.Value) {
	if !y.IsValid() {
		return // nothing to store
	}
	defer func() {
		r := recover()
		if r == nil {
			return
		}
		msg, isString := r.(string)
		incompatible := isString &&
			(strings.Contains(msg, "type mismatch") || strings.Contains(msg, "not assignable"))
		if !incompatible {
			panic(r)
		}
		// y's type does not fit x: ignore this rewrite.
	}()
	x.Set(y)
}
| 155 | |
| 156 | // Values/types for special cases. |
| 157 | var ( |
| 158 | objectPtrNil = reflect.ValueOf((*ast.Object)(nil)) |
| 159 | scopePtrNil = reflect.ValueOf((*ast.Scope)(nil)) |
| 160 | |
| 161 | identType = reflect.TypeOf((*ast.Ident)(nil)) |
| 162 | selectorExprType = reflect.TypeOf((*ast.SelectorExpr)(nil)) |
| 163 | objectPtrType = reflect.TypeOf((*ast.Object)(nil)) |
| 164 | statementType = reflect.TypeOf((*ast.Stmt)(nil)).Elem() |
| 165 | positionType = reflect.TypeOf(token.NoPos) |
| 166 | scopePtrType = reflect.TypeOf((*ast.Scope)(nil)) |
| 167 | ) |
| 168 | |
// apply stores f(child) back into every child of val and returns val.
// Children are the fields of a struct, the elements of a slice, or the
// dynamic value of an interface, looking through one level of pointer.
// A child for which f returns an invalid or incompatible value is left as
// is (see setValue).
//
// f receives a reflect.Value for the child and returns three things:
//   - its replacement;
//   - whether anything changed;
//   - the wildcard bindings (identifier name to ast.Expr) of the rewrite,
//     so the enclosing statement can be set up in the right context.
//
// For struct and non-statement slice children, "changed" is the OR over
// all children. The bindings returned are those of the last child that
// changed, so if two rewrites occur inside the same statement, setup is
// generated for only one of them.
//
// Slices of ast.Stmt are handled differently. A new []ast.Stmt is built
// and returned. In it, each statement whose subtree was rewritten is
// preceded by a copy of tr.afterStmts instantiated with that rewrite's
// bindings. The change is fully handled there, so false is reported
// upward.
//
// *ast.Object and *ast.Scope values are not traversed; they are replaced
// with nil.
func (tr *Transformer) apply(f func(reflect.Value) (reflect.Value, bool, map[string]ast.Expr), val reflect.Value) (reflect.Value, bool, map[string]ast.Expr) {
	if !val.IsValid() {
		return reflect.Value{}, false, nil
	}

	// Objects form cycles and, like scopes, are likely stale after a
	// rewrite: don't follow either, replace them with nil.
	switch val.Type() {
	case objectPtrType:
		return objectPtrNil, false, nil
	case scopePtrType:
		return scopePtrNil, false, nil
	}

	v := reflect.Indirect(val)
	kind := v.Kind()

	if kind == reflect.Interface {
		repl, changed, env := f(v.Elem())
		setValue(v, repl)
		return val, changed, env
	}

	if kind == reflect.Slice && v.Type().Elem() == statementType {
		var stmts []ast.Stmt
		for i := 0; i < v.Len(); i++ {
			slot := v.Index(i)
			repl, changed, env := f(slot)
			if changed {
				// Instantiate the template's setup statements with this
				// rewrite's bindings and place them before the statement.
				for _, tmpl := range tr.afterStmts {
					setup := tr.subst(env, reflect.ValueOf(tmpl), reflect.Value{})
					stmts = append(stmts, setup.Interface().(ast.Stmt))
				}
			}
			setValue(slot, repl)
			stmts = append(stmts, slot.Interface().(ast.Stmt))
		}
		return reflect.ValueOf(stmts), false, nil
	}

	// Slices (of anything but statements) and structs: rewrite each
	// element/field in place.
	var (
		child func(int) reflect.Value
		count int
	)
	switch kind {
	case reflect.Slice:
		child, count = v.Index, v.Len()
	case reflect.Struct:
		child, count = v.Field, v.NumField()
	default:
		return val, false, nil
	}

	anyChanged := false
	var lastEnv map[string]ast.Expr
	for i := 0; i < count; i++ {
		slot := child(i)
		repl, changed, env := f(slot)
		if changed {
			anyChanged, lastEnv = true, env
		}
		setValue(slot, repl)
	}
	return val, anyChanged, lastEnv
}
| 250 | |
// subst returns a deep copy of the (replacement) pattern. In the copy:
//   - every *ast.Ident whose name is bound in env is replaced by a copy of
//     the bound expression;
//   - if pos is valid, every valid token.Pos is replaced by pos.
//
// If env == nil, no wildcard substitution is done. If pos is invalid,
// positions are copied unchanged.
//
// While copying it also:
//   - replaces *ast.Object and *ast.Scope pointers with nil;
//   - emits references to objects in tr.importedObjs using the selector
//     recorded there: unqualified for objects of tr.currentPkg, qualified
//     otherwise;
//   - duplicates tr.info's type information for each cloned ast.Expr.
func (tr *Transformer) subst(env map[string]ast.Expr, pattern, pos reflect.Value) reflect.Value {
	if !pattern.IsValid() {
		return reflect.Value{}
	}

	// *ast.Objects introduce cycles and are likely incorrect after
	// rewrite; don't follow them but replace with nil instead.
	if pattern.Type() == objectPtrType {
		return objectPtrNil
	}

	// Similarly for scopes: they are likely incorrect after a rewrite;
	// replace them with nil.
	if pattern.Type() == scopePtrType {
		return scopePtrNil
	}

	// A wildcard (an identifier whose name is bound in env) is replaced by
	// a copy of its binding. The binding is copied with nil env and an
	// invalid pos, so its own identifiers and positions are kept as is.
	//
	// A nil *ast.Ident is not a wildcard. It represents an absent optional
	// name, such as the Label of an unlabeled break/continue in the
	// template. It falls through and is copied as a nil pointer by the
	// reflect.Ptr case below; dereferencing it here would panic.
	if env != nil && pattern.Type() == identType {
		if id := pattern.Interface().(*ast.Ident); id != nil {
			if old, ok := env[id.Name]; ok {
				return tr.subst(nil, reflect.ValueOf(old), reflect.Value{})
			}
		}
	}

	// Emit qualified identifiers in the pattern by appropriate
	// (possibly qualified) identifier in the input.
	//
	// The template cannot contain dot imports, so all identifiers
	// for imported objects are explicitly qualified.
	//
	// We assume (unsoundly) that there are no dot or named
	// imports in the input code, nor are any imported package
	// names shadowed, so the usual normal qualified identifier
	// syntax may be used.
	// TODO(adonovan): fix: avoid this assumption.
	//
	// A refactoring may be applied to a package referenced by the
	// template. Objects belonging to the current package are
	// denoted by unqualified identifiers.
	if tr.importedObjs != nil && pattern.Type() == selectorExprType {
		obj := isRef(pattern.Interface().(*ast.SelectorExpr), tr.info)
		if obj != nil {
			if sel, ok := tr.importedObjs[obj]; ok {
				var id ast.Expr
				if obj.Pkg() == tr.currentPkg {
					id = sel.Sel // unqualified
				} else {
					id = sel // pkg-qualified
				}

				// Return a clone of id. importedObjs is cleared during the
				// copy so the selector inside id is not looked up again.
				saved := tr.importedObjs
				tr.importedObjs = nil // break cycle
				r := tr.subst(nil, reflect.ValueOf(id), pos)
				tr.importedObjs = saved
				return r
			}
		}
	}

	if pos.IsValid() && pattern.Type() == positionType {
		// Use the new position only if the old one was valid to begin with.
		if old := pattern.Interface().(token.Pos); !old.IsValid() {
			return pattern
		}
		return pos
	}

	// Otherwise copy.
	switch p := pattern; p.Kind() {
	case reflect.Slice:
		v := reflect.MakeSlice(p.Type(), p.Len(), p.Len())
		for i := 0; i < p.Len(); i++ {
			v.Index(i).Set(tr.subst(env, p.Index(i), pos))
		}
		return v

	case reflect.Struct:
		v := reflect.New(p.Type()).Elem()
		for i := 0; i < p.NumField(); i++ {
			v.Field(i).Set(tr.subst(env, p.Field(i), pos))
		}
		return v

	case reflect.Ptr:
		v := reflect.New(p.Type()).Elem()
		if elem := p.Elem(); elem.IsValid() {
			v.Set(tr.subst(env, elem, pos).Addr())
		}

		// Duplicate type information for duplicated ast.Expr.
		// All ast.Node implementations are *structs,
		// so this case catches them all.
		if e := rvToExpr(v); e != nil {
			updateTypeInfo(tr.info, e, p.Interface().(ast.Expr))
		}
		return v

	case reflect.Interface:
		v := reflect.New(p.Type()).Elem()
		if elem := p.Elem(); elem.IsValid() {
			v.Set(tr.subst(env, elem, pos))
		}
		return v
	}

	return pattern
}
| 365 | |
| 366 | // -- utilities ------------------------------------------------------- |
| 367 | |
// rvToExpr returns the ast.Expr held by rv, or nil if rv does not hold one.
//
// It also returns nil when rv:
//   - is the zero Value;
//   - cannot be converted to an interface (e.g. it was obtained through an
//     unexported field);
//   - holds a nil pointer.
//
// The nil-pointer case covers absent optional nodes, such as the Tag of
// an untagged *ast.Field or the Label of an unlabeled *ast.BranchStmt.
// Without it, the result would be a non-nil ast.Expr wrapping a nil
// pointer, and callers would try to match, or copy type information for,
// a node that does not exist.
func rvToExpr(rv reflect.Value) ast.Expr {
	if !rv.IsValid() || !rv.CanInterface() {
		return nil
	}
	e, ok := rv.Interface().(ast.Expr)
	if !ok {
		return nil
	}
	if ev := reflect.ValueOf(e); ev.Kind() == reflect.Ptr && ev.IsNil() {
		return nil
	}
	return e
}
| 376 | |
| 377 | // updateTypeInfo duplicates type information for the existing AST old |
| 378 | // so that it also applies to duplicated AST new. |
| 379 | func updateTypeInfo(info *types.Info, new, old ast.Expr) { |
| 380 | switch new := new.(type) { |
| 381 | case *ast.Ident: |
| 382 | orig := old.(*ast.Ident) |
| 383 | if obj, ok := info.Defs[orig]; ok { |
| 384 | info.Defs[new] = obj |
| 385 | } |
| 386 | if obj, ok := info.Uses[orig]; ok { |
| 387 | info.Uses[new] = obj |
| 388 | } |
| 389 | |
| 390 | case *ast.SelectorExpr: |
| 391 | orig := old.(*ast.SelectorExpr) |
| 392 | if sel, ok := info.Selections[orig]; ok { |
| 393 | info.Selections[new] = sel |
| 394 | } |
| 395 | } |
| 396 | |
| 397 | if tv, ok := info.Types[old]; ok { |
| 398 | info.Types[new] = tv |
| 399 | } |
| 400 | } |
| 401 |
Members