| 1 | // Copyright 2012 The Go Authors. All rights reserved. |
|---|---|
| 2 | // Use of this source code is governed by a BSD-style |
| 3 | // license that can be found in the LICENSE file. |
| 4 | |
| 5 | // Package loopclosure defines an Analyzer that checks for references to |
| 6 | // enclosing loop variables from within nested functions. |
| 7 | package loopclosure |
| 8 | |
| 9 | import ( |
| 10 | "go/ast" |
| 11 | "go/types" |
| 12 | |
| 13 | "golang.org/x/tools/go/analysis" |
| 14 | "golang.org/x/tools/go/analysis/passes/inspect" |
| 15 | "golang.org/x/tools/go/ast/inspector" |
| 16 | "golang.org/x/tools/go/types/typeutil" |
| 17 | ) |
| 18 | |
// Doc is the user-facing documentation of the loopclosure analyzer; it is
// the value assigned to Analyzer.Doc.
//
// NOTE(review): the example snippets inside the text carry no leading
// indentation in this copy of the file — confirm against the canonical
// source whether they were meant to be indented.
const Doc = `check references to loop variables from within nested functions

This analyzer reports places where a function literal references the
iteration variable of an enclosing loop, and the loop calls the function
in such a way (e.g. with go or defer) that it may outlive the loop
iteration and possibly observe the wrong value of the variable.

In this example, all the deferred functions run after the loop has
completed, so all observe the final value of v.

for _, v := range list {
defer func() {
use(v) // incorrect
}()
}

One fix is to create a new variable for each iteration of the loop:

for _, v := range list {
v := v // new var per iteration
defer func() {
use(v) // ok
}()
}

The next example uses a go statement and has a similar problem.
In addition, it has a data race because the loop updates v
concurrent with the goroutines accessing it.

for _, v := range elem {
go func() {
use(v) // incorrect, and a data race
}()
}

A fix is the same as before. The checker also reports problems
in goroutines started by golang.org/x/sync/errgroup.Group.
A hard-to-spot variant of this form is common in parallel tests:

func Test(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
t.Parallel()
use(test) // incorrect, and a data race
})
}
}

The t.Parallel() call causes the rest of the function to execute
concurrent with the loop.

The analyzer reports references only in the last statement,
as it is not deep enough to understand the effects of subsequent
statements that might render the reference benign.
("Last statement" is defined recursively in compound
statements such as if, switch, and select.)

See: https://golang.org/doc/go_faq.html#closures_and_goroutines`
| 77 | |
// Analyzer is the loopclosure analysis. It is named "loopclosure", is
// documented by Doc, and is implemented by run. It declares a dependency on
// inspect.Analyzer, whose result run reads to traverse the syntax trees.
var Analyzer = &analysis.Analyzer{
	Name:     "loopclosure",
	Doc:      Doc,
	Requires: []*analysis.Analyzer{inspect.Analyzer},
	Run:      run,
}
| 84 | |
| 85 | func run(pass *analysis.Pass) (interface{}, error) { |
| 86 | inspect := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector) |
| 87 | |
| 88 | nodeFilter := []ast.Node{ |
| 89 | (*ast.RangeStmt)(nil), |
| 90 | (*ast.ForStmt)(nil), |
| 91 | } |
| 92 | inspect.Preorder(nodeFilter, func(n ast.Node) { |
| 93 | // Find the variables updated by the loop statement. |
| 94 | var vars []types.Object |
| 95 | addVar := func(expr ast.Expr) { |
| 96 | if id, _ := expr.(*ast.Ident); id != nil { |
| 97 | if obj := pass.TypesInfo.ObjectOf(id); obj != nil { |
| 98 | vars = append(vars, obj) |
| 99 | } |
| 100 | } |
| 101 | } |
| 102 | var body *ast.BlockStmt |
| 103 | switch n := n.(type) { |
| 104 | case *ast.RangeStmt: |
| 105 | body = n.Body |
| 106 | addVar(n.Key) |
| 107 | addVar(n.Value) |
| 108 | case *ast.ForStmt: |
| 109 | body = n.Body |
| 110 | switch post := n.Post.(type) { |
| 111 | case *ast.AssignStmt: |
| 112 | // e.g. for p = head; p != nil; p = p.next |
| 113 | for _, lhs := range post.Lhs { |
| 114 | addVar(lhs) |
| 115 | } |
| 116 | case *ast.IncDecStmt: |
| 117 | // e.g. for i := 0; i < n; i++ |
| 118 | addVar(post.X) |
| 119 | } |
| 120 | } |
| 121 | if vars == nil { |
| 122 | return |
| 123 | } |
| 124 | |
| 125 | // Inspect statements to find function literals that may be run outside of |
| 126 | // the current loop iteration. |
| 127 | // |
| 128 | // For go, defer, and errgroup.Group.Go, we ignore all but the last |
| 129 | // statement, because it's hard to prove go isn't followed by wait, or |
| 130 | // defer by return. "Last" is defined recursively. |
| 131 | // |
| 132 | // TODO: consider allowing the "last" go/defer/Go statement to be followed by |
| 133 | // N "trivial" statements, possibly under a recursive definition of "trivial" |
| 134 | // so that that checker could, for example, conclude that a go statement is |
| 135 | // followed by an if statement made of only trivial statements and trivial expressions, |
| 136 | // and hence the go statement could still be checked. |
| 137 | forEachLastStmt(body.List, func(last ast.Stmt) { |
| 138 | var stmts []ast.Stmt |
| 139 | switch s := last.(type) { |
| 140 | case *ast.GoStmt: |
| 141 | stmts = litStmts(s.Call.Fun) |
| 142 | case *ast.DeferStmt: |
| 143 | stmts = litStmts(s.Call.Fun) |
| 144 | case *ast.ExprStmt: // check for errgroup.Group.Go |
| 145 | if call, ok := s.X.(*ast.CallExpr); ok { |
| 146 | stmts = litStmts(goInvoke(pass.TypesInfo, call)) |
| 147 | } |
| 148 | } |
| 149 | for _, stmt := range stmts { |
| 150 | reportCaptured(pass, vars, stmt) |
| 151 | } |
| 152 | }) |
| 153 | |
| 154 | // Also check for testing.T.Run (with T.Parallel). |
| 155 | // We consider every t.Run statement in the loop body, because there is |
| 156 | // no commonly used mechanism for synchronizing parallel subtests. |
| 157 | // It is of course theoretically possible to synchronize parallel subtests, |
| 158 | // though such a pattern is likely to be exceedingly rare as it would be |
| 159 | // fighting against the test runner. |
| 160 | for _, s := range body.List { |
| 161 | switch s := s.(type) { |
| 162 | case *ast.ExprStmt: |
| 163 | if call, ok := s.X.(*ast.CallExpr); ok { |
| 164 | for _, stmt := range parallelSubtest(pass.TypesInfo, call) { |
| 165 | reportCaptured(pass, vars, stmt) |
| 166 | } |
| 167 | |
| 168 | } |
| 169 | } |
| 170 | } |
| 171 | }) |
| 172 | return nil, nil |
| 173 | } |
| 174 | |
| 175 | // reportCaptured reports a diagnostic stating a loop variable |
| 176 | // has been captured by a func literal if checkStmt has escaping |
| 177 | // references to vars. vars is expected to be variables updated by a loop statement, |
| 178 | // and checkStmt is expected to be a statements from the body of a func literal in the loop. |
| 179 | func reportCaptured(pass *analysis.Pass, vars []types.Object, checkStmt ast.Stmt) { |
| 180 | ast.Inspect(checkStmt, func(n ast.Node) bool { |
| 181 | id, ok := n.(*ast.Ident) |
| 182 | if !ok { |
| 183 | return true |
| 184 | } |
| 185 | obj := pass.TypesInfo.Uses[id] |
| 186 | if obj == nil { |
| 187 | return true |
| 188 | } |
| 189 | for _, v := range vars { |
| 190 | if v == obj { |
| 191 | pass.ReportRangef(id, "loop variable %s captured by func literal", id.Name) |
| 192 | } |
| 193 | } |
| 194 | return true |
| 195 | }) |
| 196 | } |
| 197 | |
// forEachLastStmt calls onLast on each "last" statement in a list of
// statements.
//
// "Last" is defined recursively: when the final statement of stmts is a
// compound statement (if, for, range, switch, type switch, or select),
// onLast is not called on it; instead the final statement of each of its
// bodies, branches or clauses is examined in the same way. Any other final
// statement is passed to onLast directly. An empty list yields no calls.
func forEachLastStmt(stmts []ast.Stmt, onLast func(last ast.Stmt)) {
	if len(stmts) == 0 {
		return
	}

	s := stmts[len(stmts)-1]
	switch s := s.(type) {
	case *ast.IfStmt:
		// Walk the if / else-if chain iteratively, visiting each branch body.
	loop:
		for {
			forEachLastStmt(s.Body.List, onLast)
			switch e := s.Else.(type) {
			case *ast.BlockStmt:
				// Final else block: visit it and finish.
				forEachLastStmt(e.List, onLast)
				break loop
			case *ast.IfStmt:
				// else-if: continue down the chain.
				s = e
			default:
				// No else branch, or an else node of an unexpected kind
				// (such as *ast.BadStmt). Stop here: without this branch an
				// unexpected node would leave s unchanged and the loop
				// would never terminate.
				break loop
			}
		}
	case *ast.ForStmt:
		forEachLastStmt(s.Body.List, onLast)
	case *ast.RangeStmt:
		forEachLastStmt(s.Body.List, onLast)
	case *ast.SwitchStmt:
		for _, c := range s.Body.List {
			cc := c.(*ast.CaseClause)
			forEachLastStmt(cc.Body, onLast)
		}
	case *ast.TypeSwitchStmt:
		for _, c := range s.Body.List {
			cc := c.(*ast.CaseClause)
			forEachLastStmt(cc.Body, onLast)
		}
	case *ast.SelectStmt:
		for _, c := range s.Body.List {
			cc := c.(*ast.CommClause)
			forEachLastStmt(cc.Body, onLast)
		}
	default:
		onLast(s)
	}
}
| 246 | |
// litStmts returns the statements that make up the body of fun when fun is
// a function literal.
//
// For any other expression, including a nil one, it returns nil.
func litStmts(fun ast.Expr) []ast.Stmt {
	if lit, ok := fun.(*ast.FuncLit); ok && lit != nil {
		return lit.Body.List
	}
	return nil
}
| 258 | |
| 259 | // goInvoke returns a function expression that would be called asynchronously |
| 260 | // (but not awaited) in another goroutine as a consequence of the call. |
| 261 | // For example, given the g.Go call below, it returns the function literal expression. |
| 262 | // |
| 263 | // import "sync/errgroup" |
| 264 | // var g errgroup.Group |
| 265 | // g.Go(func() error { ... }) |
| 266 | // |
| 267 | // Currently only "golang.org/x/sync/errgroup.Group()" is considered. |
| 268 | func goInvoke(info *types.Info, call *ast.CallExpr) ast.Expr { |
| 269 | if !isMethodCall(info, call, "golang.org/x/sync/errgroup", "Group", "Go") { |
| 270 | return nil |
| 271 | } |
| 272 | return call.Args[0] |
| 273 | } |
| 274 | |
| 275 | // parallelSubtest returns statements that can be easily proven to execute |
| 276 | // concurrently via the go test runner, as t.Run has been invoked with a |
| 277 | // function literal that calls t.Parallel. |
| 278 | // |
| 279 | // In practice, users rely on the fact that statements before the call to |
| 280 | // t.Parallel are synchronous. For example by declaring test := test inside the |
| 281 | // function literal, but before the call to t.Parallel. |
| 282 | // |
| 283 | // Therefore, we only flag references in statements that are obviously |
| 284 | // dominated by a call to t.Parallel. As a simple heuristic, we only consider |
| 285 | // statements following the final labeled statement in the function body, to |
| 286 | // avoid scenarios where a jump would cause either the call to t.Parallel or |
| 287 | // the problematic reference to be skipped. |
| 288 | // |
| 289 | // import "testing" |
| 290 | // |
| 291 | // func TestFoo(t *testing.T) { |
| 292 | // tests := []int{0, 1, 2} |
| 293 | // for i, test := range tests { |
| 294 | // t.Run("subtest", func(t *testing.T) { |
| 295 | // println(i, test) // OK |
| 296 | // t.Parallel() |
| 297 | // println(i, test) // Not OK |
| 298 | // }) |
| 299 | // } |
| 300 | // } |
| 301 | func parallelSubtest(info *types.Info, call *ast.CallExpr) []ast.Stmt { |
| 302 | if !isMethodCall(info, call, "testing", "T", "Run") { |
| 303 | return nil |
| 304 | } |
| 305 | |
| 306 | lit, _ := call.Args[1].(*ast.FuncLit) |
| 307 | if lit == nil { |
| 308 | return nil |
| 309 | } |
| 310 | |
| 311 | // Capture the *testing.T object for the first argument to the function |
| 312 | // literal. |
| 313 | if len(lit.Type.Params.List[0].Names) == 0 { |
| 314 | return nil |
| 315 | } |
| 316 | |
| 317 | tObj := info.Defs[lit.Type.Params.List[0].Names[0]] |
| 318 | if tObj == nil { |
| 319 | return nil |
| 320 | } |
| 321 | |
| 322 | // Match statements that occur after a call to t.Parallel following the final |
| 323 | // labeled statement in the function body. |
| 324 | // |
| 325 | // We iterate over lit.Body.List to have a simple, fast and "frequent enough" |
| 326 | // dominance relationship for t.Parallel(): lit.Body.List[i] dominates |
| 327 | // lit.Body.List[j] for i < j unless there is a jump. |
| 328 | var stmts []ast.Stmt |
| 329 | afterParallel := false |
| 330 | for _, stmt := range lit.Body.List { |
| 331 | stmt, labeled := unlabel(stmt) |
| 332 | if labeled { |
| 333 | // Reset: naively we don't know if a jump could have caused the |
| 334 | // previously considered statements to be skipped. |
| 335 | stmts = nil |
| 336 | afterParallel = false |
| 337 | } |
| 338 | |
| 339 | if afterParallel { |
| 340 | stmts = append(stmts, stmt) |
| 341 | continue |
| 342 | } |
| 343 | |
| 344 | // Check if stmt is a call to t.Parallel(), for the correct t. |
| 345 | exprStmt, ok := stmt.(*ast.ExprStmt) |
| 346 | if !ok { |
| 347 | continue |
| 348 | } |
| 349 | expr := exprStmt.X |
| 350 | if isMethodCall(info, expr, "testing", "T", "Parallel") { |
| 351 | call, _ := expr.(*ast.CallExpr) |
| 352 | if call == nil { |
| 353 | continue |
| 354 | } |
| 355 | x, _ := call.Fun.(*ast.SelectorExpr) |
| 356 | if x == nil { |
| 357 | continue |
| 358 | } |
| 359 | id, _ := x.X.(*ast.Ident) |
| 360 | if id == nil { |
| 361 | continue |
| 362 | } |
| 363 | if info.Uses[id] == tObj { |
| 364 | afterParallel = true |
| 365 | } |
| 366 | } |
| 367 | } |
| 368 | |
| 369 | return stmts |
| 370 | } |
| 371 | |
// unlabel strips every enclosing *ast.LabeledStmt wrapper from stmt,
// however deeply nested, and returns the statement underneath.
//
// The boolean result is true when at least one label was removed, that is,
// when stmt itself was an *ast.LabeledStmt.
func unlabel(stmt ast.Stmt) (ast.Stmt, bool) {
	inner := stmt
	wasLabeled := false
	for ls, ok := inner.(*ast.LabeledStmt); ok; ls, ok = inner.(*ast.LabeledStmt) {
		wasLabeled = true
		inner = ls.Stmt
	}
	return inner, wasLabeled
}
| 387 | |
| 388 | // isMethodCall reports whether expr is a method call of |
| 389 | // <pkgPath>.<typeName>.<method>. |
| 390 | func isMethodCall(info *types.Info, expr ast.Expr, pkgPath, typeName, method string) bool { |
| 391 | call, ok := expr.(*ast.CallExpr) |
| 392 | if !ok { |
| 393 | return false |
| 394 | } |
| 395 | |
| 396 | // Check that we are calling a method <method> |
| 397 | f := typeutil.StaticCallee(info, call) |
| 398 | if f == nil || f.Name() != method { |
| 399 | return false |
| 400 | } |
| 401 | recv := f.Type().(*types.Signature).Recv() |
| 402 | if recv == nil { |
| 403 | return false |
| 404 | } |
| 405 | |
| 406 | // Check that the receiver is a <pkgPath>.<typeName> or |
| 407 | // *<pkgPath>.<typeName>. |
| 408 | rtype := recv.Type() |
| 409 | if ptr, ok := recv.Type().(*types.Pointer); ok { |
| 410 | rtype = ptr.Elem() |
| 411 | } |
| 412 | named, ok := rtype.(*types.Named) |
| 413 | if !ok { |
| 414 | return false |
| 415 | } |
| 416 | if named.Obj().Name() != typeName { |
| 417 | return false |
| 418 | } |
| 419 | pkg := f.Pkg() |
| 420 | if pkg == nil { |
| 421 | return false |
| 422 | } |
| 423 | if pkg.Path() != pkgPath { |
| 424 | return false |
| 425 | } |
| 426 | |
| 427 | return true |
| 428 | } |
| 429 |
Members