| 1 | // Copyright 2021 The Go Authors. All rights reserved. |
|---|---|
| 2 | // Use of this source code is governed by a BSD-style |
| 3 | // license that can be found in the LICENSE file. |
| 4 | |
| 5 | package generator |
| 6 | |
import (
	"fmt"
	"math/rand"
	"os"
	"path/filepath"
	"runtime"
	"strings"
)
| 14 | |
// Control flags for a wraprand, combined with bitwise OR and passed to
// NewWrapRand as ctl. Any nonzero ctl enables the consistency check in
// Check; the individual bits select extra behavior.
const (
	// RandCtlNochecks is the zero value: no flags set.
	RandCtlNochecks = 0
	// RandCtlChecks is a flag bit with no other effect than making ctl nonzero.
	RandCtlChecks = 1 << 1
	// RandCtlCapture makes each random call record its value and call stack.
	RandCtlCapture = 1 << 2
	// RandCtlPanic makes Check panic after reporting a mismatch.
	RandCtlPanic = 1 << 3
)
| 21 | |
| 22 | func NewWrapRand(seed int64, ctl int) *wraprand { |
| 23 | rand.Seed(seed) |
| 24 | return &wraprand{seed: seed, ctl: ctl} |
| 25 | } |
| 26 | |
// wraprand wraps the package-level math/rand functions, counting how many
// times each wrapped method is called and optionally recording every call.
type wraprand struct {
	// Number of calls made to Float32, NormFloat64 and Intn respectively.
	f32calls, f64calls, intncalls int

	// seed is the value passed to NewWrapRand.
	seed int64

	// tag, when non-empty, labels this wrapper in Check's report.
	tag string

	// calls holds the captured call records and checkpoint markers.
	calls []string

	// ctl is a bitwise OR of the RandCtl* flags.
	ctl int
}
| 36 | |
| 37 | func (w *wraprand) captureCall(tag string, val string) { |
| 38 | call := tag + ": " + val + "\n" |
| 39 | pc := make([]uintptr, 10) |
| 40 | n := runtime.Callers(1, pc) |
| 41 | if n == 0 { |
| 42 | panic("why?") |
| 43 | } |
| 44 | pc = pc[:n] // pass only valid pcs to runtime.CallersFrames |
| 45 | frames := runtime.CallersFrames(pc) |
| 46 | for { |
| 47 | frame, more := frames.Next() |
| 48 | if strings.Contains(frame.File, "testing.") { |
| 49 | break |
| 50 | } |
| 51 | call += fmt.Sprintf("%s %s:%d\n", frame.Function, frame.File, frame.Line) |
| 52 | if !more { |
| 53 | break |
| 54 | } |
| 55 | |
| 56 | } |
| 57 | w.calls = append(w.calls, call) |
| 58 | } |
| 59 | |
| 60 | func (w *wraprand) Intn(n int64) int64 { |
| 61 | w.intncalls++ |
| 62 | rv := rand.Int63n(n) |
| 63 | if w.ctl&RandCtlCapture != 0 { |
| 64 | w.captureCall("Intn", fmt.Sprintf("%d", rv)) |
| 65 | } |
| 66 | return rv |
| 67 | } |
| 68 | |
| 69 | func (w *wraprand) Float32() float32 { |
| 70 | w.f32calls++ |
| 71 | rv := rand.Float32() |
| 72 | if w.ctl&RandCtlCapture != 0 { |
| 73 | w.captureCall("Float32", fmt.Sprintf("%f", rv)) |
| 74 | } |
| 75 | return rv |
| 76 | } |
| 77 | |
| 78 | func (w *wraprand) NormFloat64() float64 { |
| 79 | w.f64calls++ |
| 80 | rv := rand.NormFloat64() |
| 81 | if w.ctl&RandCtlCapture != 0 { |
| 82 | w.captureCall("NormFloat64", fmt.Sprintf("%f", rv)) |
| 83 | } |
| 84 | return rv |
| 85 | } |
| 86 | |
| 87 | func (w *wraprand) emitCalls(fn string) { |
| 88 | outf, err := os.OpenFile(fn, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0666) |
| 89 | if err != nil { |
| 90 | panic(err) |
| 91 | } |
| 92 | for _, c := range w.calls { |
| 93 | fmt.Fprint(outf, c) |
| 94 | } |
| 95 | outf.Close() |
| 96 | } |
| 97 | |
| 98 | func (w *wraprand) Equal(w2 *wraprand) bool { |
| 99 | return w.f32calls == w2.f32calls && |
| 100 | w.f64calls == w2.f64calls && |
| 101 | w.intncalls == w2.intncalls |
| 102 | } |
| 103 | |
| 104 | func (w *wraprand) Check(w2 *wraprand) { |
| 105 | if w.ctl != 0 && !w.Equal(w2) { |
| 106 | fmt.Fprintf(os.Stderr, "wraprand consistency check failed:\n") |
| 107 | t := "w" |
| 108 | if w.tag != "" { |
| 109 | t = w.tag |
| 110 | } |
| 111 | t2 := "w2" |
| 112 | if w2.tag != "" { |
| 113 | t2 = w2.tag |
| 114 | } |
| 115 | fmt.Fprintf(os.Stderr, " %s: {f32:%d f64:%d i:%d}\n", t, |
| 116 | w.f32calls, w.f64calls, w.intncalls) |
| 117 | fmt.Fprintf(os.Stderr, " %s: {f32:%d f64:%d i:%d}\n", t2, |
| 118 | w2.f32calls, w2.f64calls, w2.intncalls) |
| 119 | if w.ctl&RandCtlCapture != 0 { |
| 120 | f := fmt.Sprintf("/tmp/%s.txt", t) |
| 121 | f2 := fmt.Sprintf("/tmp/%s.txt", t2) |
| 122 | w.emitCalls(f) |
| 123 | w2.emitCalls(f2) |
| 124 | fmt.Fprintf(os.Stderr, "=-= emitted calls to %s, %s\n", f, f2) |
| 125 | } |
| 126 | if w.ctl&RandCtlPanic != 0 { |
| 127 | panic("bad") |
| 128 | } |
| 129 | } |
| 130 | } |
| 131 | |
| 132 | func (w *wraprand) Checkpoint(tag string) { |
| 133 | if w.ctl&RandCtlCapture != 0 { |
| 134 | w.calls = append(w.calls, "=-=\n"+tag+"\n=-=\n") |
| 135 | } |
| 136 | } |
| 137 |
Members