Passing data from handler to middleware after serving request - api

I have the following simple API in Go:
package main
import (
"context"
"fmt"
"net/http"
"github.com/gorilla/mux"
)
// middleware wraps next so that, once next has served the request, the value
// stored under "custom_data" on the request context is printed.
func middleware(next http.Handler) http.Handler {
	logAfter := func(w http.ResponseWriter, r *http.Request) {
		// Let the wrapped handler serve the request first.
		next.ServeHTTP(w, r)

		// r is still the request this closure received, so its context is
		// the one inspected here.
		fmt.Println("Custom data:", r.Context().Value("custom_data"))
	}
	return http.HandlerFunc(logAfter)
}
// handler derives a request whose context carries "custom_data" and binds it
// to the local parameter only; the *http.Request held by the caller is left
// untouched.
func handler(w http.ResponseWriter, reqIn *http.Request) {
	ctx := context.WithValue(reqIn.Context(), "custom_data", true)
	reqIn = reqIn.WithContext(ctx)
}
// main builds a gorilla/mux router, installs middleware on it, routes GET /
// to handler and serves on port 8080.
func main() {
	router := mux.NewRouter()
	// Every route registered on the router passes through middleware.
	router.Use(middleware)
	// Only GET requests for "/" reach handler.
	router.HandleFunc("/", handler).Methods("GET")
	http.ListenAndServe(":8080", router)
}
I expected the middleware to be able to read the value of "custom_data" from the request context, but it cannot: it gets <nil> for that context value.
This happens even if I use Clone instead of WithContext for adding a value in the context of the request.
Looking around, specifically this post, if I instead use this as the handler:
// handler stores "custom_data" in a derived context and then overwrites the
// request struct the caller points at, so the caller observes the new context.
func handler(w http.ResponseWriter, reqIn *http.Request) {
	withData := context.WithValue(reqIn.Context(), "custom_data", true)
	// WithContext returns a shallow copy; copying it over *reqIn makes the
	// change visible through the caller's pointer.
	*reqIn = *reqIn.WithContext(withData)
}
It works as expected.
But modifying the *http.Request is not the norm.
My real question that I am trying to solve is; how can I pass information from the handler to the middleware?
Adding a value to the context of the *http.Request would be able to be accessed in the middleware.

You can do the following:
// middleware hands the wrapped handler a mutable map through the request
// context under "custom_data", then prints that map once the handler returns.
func middleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
		// The map is a reference type, so writes made by the handler are
		// visible here afterwards.
		shared := map[string]any{}
		ctx := context.WithValue(req.Context(), "custom_data", shared)
		req = req.WithContext(ctx)

		// Call the handler
		next.ServeHTTP(w, req)

		// Read the data back through the context ...
		fromCtx := req.Context().Value("custom_data")
		fmt.Printf("Custom data(%T): %v\n", fromCtx, fromCtx)
		// ... or through the map variable captured above.
		fmt.Printf("Custom data(%T): %v\n", shared, shared)
	})
}
// handler records value=true in the map found under "custom_data" on the
// request context; it does nothing when that entry is missing, is not a
// map[string]any, or is a nil map.
func handler(w http.ResponseWriter, r *http.Request) {
	data, isMap := r.Context().Value("custom_data").(map[string]any)
	if !isMap || data == nil {
		return
	}
	data["value"] = true
}

Related

Efficient way to make REST handlers in Go (without repeating code)?

Currently I have too much repeated code for the handlers:
// GuestMux is the multiplexer used for the guest endpoints. It embeds
// http.ServeMux, so the ServeMux methods are available on it directly.
type GuestMux struct {
http.ServeMux
}
// main serves the guest multiplexer on port 3001.
func main() {
	http.ListenAndServe(":3001", NewGuestMux())
}
// NewGuestMux returns a GuestMux with the create, update and get guest
// handlers registered under their /guest/... paths.
func NewGuestMux() *GuestMux {
	gm := new(GuestMux)
	routes := []struct {
		pattern string
		handler func(http.ResponseWriter, *http.Request)
	}{
		{"/guest/createguest", createGuestHandler},
		{"/guest/updateguest", updateGuestHandler},
		{"/guest/getguest", getGuestHandler},
	}
	for _, rt := range routes {
		gm.HandleFunc(rt.pattern, rt.handler)
	}
	return gm
}
// createGuestHandler reads the whole request body, unmarshals it into a
// CreateGuestRequest, calls CreateGuest and writes the result as JSON.
// A read failure or a CreateGuest error yields 500; an unmarshal error yields 400.
func createGuestHandler(w http.ResponseWriter, r *http.Request) {
	body, readErr := ioutil.ReadAll(r.Body)
	if readErr != nil {
		log.Println(readErr)
		w.WriteHeader(http.StatusInternalServerError)
		return
	}

	var req CreateGuestRequest
	if decodeErr := json.Unmarshal(body, &req); decodeErr != nil {
		log.Println(decodeErr)
		w.WriteHeader(http.StatusBadRequest)
		return
	}

	created, createErr := CreateGuest(&req)
	if createErr != nil {
		log.Println(createErr)
		w.WriteHeader(http.StatusInternalServerError)
		return
	}

	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(created)
}
// updateGuestHandler reads the whole request body, unmarshals it into an
// UpdateGuestRequest, calls UpdateGuest and writes the result as JSON.
// A read failure or an UpdateGuest error yields 500; an unmarshal error yields 400.
func updateGuestHandler(w http.ResponseWriter, r *http.Request) {
	body, readErr := ioutil.ReadAll(r.Body)
	if readErr != nil {
		log.Println(readErr)
		w.WriteHeader(http.StatusInternalServerError)
		return
	}

	var req UpdateGuestRequest
	if decodeErr := json.Unmarshal(body, &req); decodeErr != nil {
		log.Println(decodeErr)
		w.WriteHeader(http.StatusBadRequest)
		return
	}

	updated, updateErr := UpdateGuest(&req)
	if updateErr != nil {
		log.Println(updateErr)
		w.WriteHeader(http.StatusInternalServerError)
		return
	}

	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(updated)
}
// getGuestHandler is the third guest endpoint handler. Its body is elided in
// the question; the placeholder below stands for the same read/decode/call/
// encode sequence as the create and update handlers.
func getGuestHandler(w http.ResponseWriter, r *http.Request) {
// almost the same as above two handlers, just different method to call and
// its parameter type
...
}
Is there any nicer way to write the handlers createGuestHandler, updateGuestHandler and getGuestHandler instead of repeating similar code blocks three times. I guess I can use interface but am not sure how to write that. I have about 20 handlers so the repeating code does not seem really maintainable.
//stackoverflow does not allow question with too much code over details so... details here, details there, even more details...//
You can move the common logic to a separate function, and pass everything to it that is specific in each handler.
Let's assume you have these types and functions:
// Placeholder request and response types for the guest operations.
type (
	CreateGuestRequest  struct{}
	UpdateGuestRequest  struct{}
	CreateGuestResponse struct{}
	UpdateGuestResponse struct{}
)

// CreateGuest is a stub for the create logic; it always yields a nil
// response and a nil error.
func CreateGuest(v *CreateGuestRequest) (resp *CreateGuestResponse, err error) {
	return resp, err
}

// UpdateGuest is a stub for the update logic; it always yields a nil
// response and a nil error.
func UpdateGuest(v *UpdateGuestRequest) (resp *UpdateGuestResponse, err error) {
	return resp, err
}
With generics allowed
If generics are allowed, you can factor all code out of handlers:
// handle implements the plumbing shared by all JSON endpoints: it decodes the
// request body into a value of type Req, passes it to logicFunc and writes
// the returned value to w as JSON.
//
// Status codes:
//   - 400 Bad Request when the body cannot be decoded (a client error),
//   - 500 Internal Server Error when logicFunc returns an error,
//   - 200 OK otherwise.
//
// An encoding failure is only logged, because the response has already been
// started by then.
func handle[Req any, Resp any](w http.ResponseWriter, r *http.Request, logicFunc func(dst Req) (Resp, error)) {
	var dst Req
	if err := json.NewDecoder(r.Body).Decode(&dst); err != nil {
		// A malformed body is the caller's mistake, not a server fault.
		log.Printf("Decoding body failed: %v", err)
		w.WriteHeader(http.StatusBadRequest)
		return
	}

	resp, err := logicFunc(dst)
	if err != nil {
		log.Println(err)
		w.WriteHeader(http.StatusInternalServerError)
		return
	}

	w.Header().Set("Content-Type", "application/json")
	if err := json.NewEncoder(w).Encode(resp); err != nil {
		log.Printf("Encoding response failed: %v", err)
	}
}
// createGuestHandler serves the create endpoint by delegating to handle with
// CreateGuest as the logic function.
func createGuestHandler(w http.ResponseWriter, r *http.Request) {
	handle[*CreateGuestRequest, *CreateGuestResponse](w, r, CreateGuest)
}

// updateGuestHandler serves the update endpoint by delegating to handle with
// UpdateGuest as the logic function.
func updateGuestHandler(w http.ResponseWriter, r *http.Request) {
	handle[*UpdateGuestRequest, *UpdateGuestResponse](w, r, UpdateGuest)
}
As you can see, all handler implementations are just a single line! We can even get rid of the handler functions now, as we can create a handler from a logic function (like CreateGuest(), UpdateGuest()).
This is how it would look like:
// createHandler turns a logic function into an http.HandlerFunc. The returned
// handler decodes the request body into a Req, calls logicFunc and writes the
// result to the response as JSON.
//
// Status codes:
//   - 400 Bad Request when the body cannot be decoded (a client error),
//   - 500 Internal Server Error when logicFunc returns an error,
//   - 200 OK otherwise.
//
// An encoding failure is only logged, because the response has already been
// started by then.
func createHandler[Req any, Resp any](logicFunc func(dst Req) (Resp, error)) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		var dst Req
		if err := json.NewDecoder(r.Body).Decode(&dst); err != nil {
			// A malformed body is the caller's mistake, not a server fault.
			log.Printf("Decoding body failed: %v", err)
			w.WriteHeader(http.StatusBadRequest)
			return
		}

		resp, err := logicFunc(dst)
		if err != nil {
			log.Println(err)
			w.WriteHeader(http.StatusInternalServerError)
			return
		}

		w.Header().Set("Content-Type", "application/json")
		if err := json.NewEncoder(w).Encode(resp); err != nil {
			log.Printf("Encoding response failed: %v", err)
		}
	}
}
And using it:
// NewGuestMux returns a GuestMux whose create and update routes are served by
// handlers generated from the corresponding logic functions.
func NewGuestMux() *GuestMux {
	guestMux := new(GuestMux)
	guestMux.HandleFunc("/guest/createguest", createHandler(CreateGuest))
	guestMux.HandleFunc("/guest/updateguest", createHandler(UpdateGuest))
	return guestMux
}
Without generics
This solution does not use generics (and works with old Go versions too).
// handle implements the plumbing shared by all JSON endpoints without using
// generics: it decodes the request body into dst (which must be a pointer),
// calls logicFunc and writes the returned value to w as JSON.
//
// Status codes:
//   - 400 Bad Request when the body cannot be decoded (a client error),
//   - 500 Internal Server Error when logicFunc returns an error,
//   - 200 OK otherwise.
//
// An encoding failure is only logged, because the response has already been
// started by then.
func handle(w http.ResponseWriter, r *http.Request, dst interface{}, logicFunc func() (interface{}, error)) {
	if err := json.NewDecoder(r.Body).Decode(dst); err != nil {
		// A malformed body is the caller's mistake, not a server fault.
		log.Printf("Decoding body failed: %v", err)
		w.WriteHeader(http.StatusBadRequest)
		return
	}

	resp, err := logicFunc()
	if err != nil {
		log.Println(err)
		w.WriteHeader(http.StatusInternalServerError)
		return
	}

	w.Header().Set("Content-Type", "application/json")
	if err := json.NewEncoder(w).Encode(resp); err != nil {
		log.Printf("Encoding response failed: %v", err)
	}
}
// createGuestHandler decodes into a fresh CreateGuestRequest and lets handle
// invoke CreateGuest on it through a closure.
func createGuestHandler(w http.ResponseWriter, r *http.Request) {
	req := new(CreateGuestRequest)
	create := func() (interface{}, error) {
		return CreateGuest(req)
	}
	handle(w, r, req, create)
}

// updateGuestHandler decodes into a fresh UpdateGuestRequest and lets handle
// invoke UpdateGuest on it through a closure.
func updateGuestHandler(w http.ResponseWriter, r *http.Request) {
	req := new(UpdateGuestRequest)
	update := func() (interface{}, error) {
		return UpdateGuest(req)
	}
	handle(w, r, req, update)
}
There are many ways to avoid repetition here. For example, you could use a decorator pattern, where you define how to decode/encode and perform the other steps that don't include your business logic.
You can check two interesting approaches:
One is from Mat: https://pace.dev/blog/2018/05/09/how-I-write-http-services-after-eight-years.html
The other one is the go-kit package (you can check it out on GitHub), but I recommend that you look at its idea of how to compose decorators rather than installing the library, which could be overkill for your implementation.
Typically REST APIs have just /guest endpoint with single handler that decides what to do based on HTTP method:
POST to create
GET to retrieve
PUT to update the entire record
PATCH to update certain fields
You can look at r.Method inside your handler and decide what code to run based on that.
If you are bound to interface shown in your question you can e.g. wrap handler to an anonymous function with expected interface and make it accept an additional argument to decide what to do like this:
// Each legacy path is registered with a thin adapter closure that forwards to
// the shared guestHandler together with a flag naming the operation.
guestMux.HandleFunc("/guest/createguest", func(w http.ResponseWriter, r *http.Request) {
guestHandler(r, w, CREATE)
})
guestMux.HandleFunc("/guest/updateguest", func(w http.ResponseWriter, r *http.Request) {
guestHandler(r, w, UPDATE)
})
// Further routes are elided in the answer.
...
(where CREATE and UPDATE are some sort of flags that tell guestHandler() what it should do)
I suggest having a look at go-kit.
It's mainly designed to create services using Hexagonal architecture. It brings a lot of utility functions to avoid repeated code and focus on the business logic.
It has a lot of functionality that you may not need, but since it's a toolkit (and not a complete framework) you're free to use only the parts that you need.
Examples are also easy to follow.
I have these utility functions : decodeJsonBody, respondJson that I use to simplify response, without adding too much complexity. I wrap it in the Response struct for sending client side error details.
// Response is the JSON envelope sent to clients: Data carries the payload of
// a successful call and Errors carries client-facing error details.
type Response struct {
	Data   interface{} `json:"data"`
	Errors interface{} `json:"errors"`
}

// respondJson writes a Response envelope to w as JSON.
//
// When err is non-nil the status is 400 Bad Request and the envelope carries
// err.Error() in Errors; otherwise data is sent in Data. If the success
// envelope cannot be encoded the status is set to 500. A failure to encode or
// write either envelope is logged rather than silently dropped.
func respondJson(w http.ResponseWriter, data interface{}, err error) {
	w.Header().Set("Content-Type", "application/json")

	if err != nil {
		w.WriteHeader(http.StatusBadRequest)
		// The status line is already sent, so all that can be done with a
		// failure here is to log it.
		if encErr := json.NewEncoder(w).Encode(Response{Errors: err.Error()}); encErr != nil {
			log.Printf("http handler failed to convert response to json %s\n", encErr)
		}
		return
	}

	if encErr := json.NewEncoder(w).Encode(Response{Data: data}); encErr != nil {
		w.WriteHeader(http.StatusInternalServerError)
		log.Printf("http handler failed to convert response to json %s\n", encErr)
	}
}
// decodeJsonBody decodes the JSON request body of r into v and returns the
// decoder's error, if any.
func decodeJsonBody(r *http.Request, v interface{}) error {
	return json.NewDecoder(r.Body).Decode(v)
}
// updateGuestHandler decodes the JSON body into an UpdateGuestRequest, calls
// UpdateGuest and reports either the result or the error through respondJson.
func updateGuestHandler(w http.ResponseWriter, r *http.Request) {
	var updateGuestReq UpdateGuestRequest
	if err := decodeJsonBody(r, &updateGuestReq); err != nil {
		respondJson(w, nil, err)
		return
	}

	data, err := UpdateGuest(&updateGuestReq)
	respondJson(w, data, err)
}

How can I pass data from controller to form in go lang?

I have a handler / controller that takes in an http request.
// UpdateHandler reads the "ID" route variable from the request and tries to
// save it through UpdateForm.
// NOTE(review): this is the question's non-working code. Save is invoked on
// the UpdateForm type rather than on a value, and with two arguments (ID, db)
// although Save is declared with a single parameter.
func UpdateHandler(request *http.Request) {
ID := mux.Vars(request)["ID"]
UpdateForm.Save(ID,db)
}
Then I have a form that I want to process the data and eventually update it.
// UpdateForm holds the identifier to update; ID is mapped to the JSON key
// "type".
type UpdateForm struct {
ID string `json:"type"`
}
// Save is meant to update the record identified by the form's ID.
// NOTE(review): this is the question's non-working code. The receiver is
// unnamed, so ID below is not the struct field, and Id differs from id in case.
func (UpdateForm) Save(db mongo.Database) {
id := ID
repository.Update(Id)
}
Go will print out undefined ID
How can I make sure that the form gets the value from the controller?
You can populate your form using data from the request. If your request contains a JSON-encoded body, then you could decode it into your form object like this:
package main
import (
"encoding/json"
"net/http"
"strings"
"fmt"
)
// UpdateForm is the form populated from the request body; ID is read from the
// JSON key "type".
type UpdateForm struct {
	ID string `json:"type"`
}

// main builds a POST request with a small JSON body, decodes that body into
// an UpdateForm and prints the resulting ID.
func main() {
	req, err := http.NewRequest(
		"POST",
		"http://example.com",
		strings.NewReader(`{"type": "foo"}`),
	)
	if err != nil {
		fmt.Println("building request:", err)
		return
	}

	// Decode into a value rather than a nil pointer: a body of `null` would
	// otherwise leave the pointer nil and the field access below would panic.
	var form UpdateForm
	if err := json.NewDecoder(req.Body).Decode(&form); err != nil {
		fmt.Println("decoding form:", err)
		return
	}
	fmt.Println(form.ID) // Output: foo
}
Or you can instantiate it directly like this:
// UpdateHandler builds an UpdateForm directly from the "ID" route variable
// and saves it.
func UpdateHandler(request *http.Request) {
	vars := mux.Vars(request)
	(&UpdateForm{ID: vars["ID"]}).Save()
}
I think it has nothing to do with the handler, but your code isn't consistent. This line
UpdateForm.Save(ID,db)
The method Save() takes two arguments, while the original method signature takes only a single mongo.Database type argument.
Here is what I assume was your intention:
// UpdateForm holds the identifier to update; ID is mapped to the JSON key
// "type".
type UpdateForm struct {
	ID string `json:"type"`
}

// Save passes the form's ID to repository.Update. The db argument is accepted
// but not used by this method.
func (u UpdateForm) Save(db mongo.Database) {
	repository.Update(u.ID)
}
// UpdateForm instance somewhere
var u = UpdateForm{}

// UpdateHandler copies the "ID" route variable into the shared form and
// saves it.
// NOTE(review): u is package-level state shared by every call, and db is not
// defined in this snippet; confirm both against the real program.
func UpdateHandler(request *http.Request) {
	// Plain assignment: ":=" cannot be used with a struct field on its left.
	u.ID = mux.Vars(request)["ID"]
	u.Save(db)
}

How to test Go function containing log.Fatal()

Say, I had the following code that prints some log messages. How would I go about testing that the correct messages have been logged? As log.Fatal calls os.Exit(1) the tests fail.
package main
import (
"log"
)
// hello logs the greeting "Hello!" through the standard logger.
func hello() {
log.Print("Hello!")
}
// goodbye logs "Goodbye!" with log.Fatal, which terminates the process.
func goodbye() {
log.Fatal("Goodbye!")
}
// init clears the standard logger's flags.
func init() {
log.SetFlags(0)
}
// main logs the greeting and then the fatal farewell.
func main() {
hello()
goodbye()
}
Here are the hypothetical tests:
package main
import (
"bytes"
"log"
"testing"
)
// TestHello checks that hello logs exactly "Hello!\n" by pointing the
// standard logger at an in-memory buffer for the duration of the test.
func TestHello(t *testing.T) {
	var buf bytes.Buffer
	// log.Writer() is evaluated now, so the deferred call restores whatever
	// destination was configured before this test redirected it.
	defer log.SetOutput(log.Writer())
	log.SetOutput(&buf)

	hello()

	wantMsg := "Hello!\n"
	if msg := buf.String(); msg != wantMsg {
		t.Errorf("%#v, wanted %#v", msg, wantMsg)
	}
}
// TestGoodby redirects the standard logger into a buffer, calls goodbye and
// compares the captured output with "Goodbye!\n".
func TestGoodby(t *testing.T) {
	captured := new(bytes.Buffer)
	log.SetOutput(captured)

	goodbye()

	const wantMsg = "Goodbye!\n"
	if msg := captured.String(); msg != wantMsg {
		t.Errorf("%#v, wanted %#v", msg, wantMsg)
	}
}
This is similar to "How to test os.Exit() scenarios in Go": you need to implement your own logger, which by default redirect to log.xxx(), but gives you the opportunity, when testing, to replace a function like log.Fatalf() with your own (which does not call os.Exit(1))
I did the same for testing os.Exit() calls in exit/exit.go:
// Build an exiter around a no-op exit function, trigger an exit with status 3
// and assert that the recorded status is 3.
exiter = New(func(int) {})
exiter.Exit(3)
So(exiter.Status(), ShouldEqual, 3)
(here, my "exit" function is an empty one which does nothing)
While it's possible to test code that contains log.Fatal, it is not recommended. In particular you cannot test that code in a way that is supported by the -cover flag on go test.
Instead it is recommended that you change your code to return an error instead of calling log.Fatal. In a sequential function you can add an additional return value, and in a goroutine you can pass an error on a channel of type chan error (or some struct type containing a field of type error).
Once that change is made your code will be much easier to read, much easier to test, and it will be more portable (now you can use it in a server program in addition to command line tools).
If you have log.Println calls I also recommend passing a custom logger as a field on a receiver. That way you can log to the custom logger, which you can set to stderr or stdout for a server, and a noop logger for tests (so you don't get a bunch of unnecessary output in your tests). The log package supports custom loggers, so there's no need to write your own or import a third party package for this.
If you're using logrus, there's now an option to define your exit function from v1.3.0 introduced in this commit. So your test may look something like:
// Test_X runs X over a table of parameters and checks, through logrus's
// replaceable ExitFunc, whether each call tried to exit.
func Test_X(t *testing.T) {
	logger := log.StandardLogger()
	// Drop the fake exit function again once the test is done.
	defer func() { logger.ExitFunc = nil }()

	// fatal records whether the fake exit function ran during the current case.
	var fatal bool
	logger.ExitFunc = func(int) { fatal = true }

	cases := []struct {
		param       string
		expectFatal bool
	}{
		{param: "valid", expectFatal: false},
		{param: "invalid", expectFatal: true},
	}
	for _, c := range cases {
		fatal = false
		X(c.param)
		assert.Equal(t, c.expectFatal, fatal)
	}
}
I have been using the following code to test my function. In xxx.go:
// logFatalf is a package-level indirection for log.Fatalf, so that tests can
// replace it with a function of their own.
var logFatalf = log.Fatalf
// Call site excerpt: the failure is reported through the replaceable variable.
if err != nil {
logFatalf("failed to init launcher, err:%v", err)
}
And in xxx_test.go:
// TestFatal is used to do tests which are supposed to be fatal
// TestFatal is used to do tests which are supposed to be fatal. It swaps
// logFatalf for a recorder that collects the formatted messages instead of
// exiting, and expects exactly one message to be recorded.
func TestFatal(t *testing.T) {
	origLogFatalf := logFatalf
	// After this test, put the original fatal function back.
	defer func() { logFatalf = origLogFatalf }()

	var msgs []string
	logFatalf = func(format string, args ...interface{}) {
		if len(args) > 0 {
			// Expand args so each verb receives its own argument instead of
			// the whole slice being formatted as one value.
			msgs = append(msgs, fmt.Sprintf(format, args...))
		} else {
			msgs = append(msgs, format)
		}
	}

	// TODO: call the code under test here; as written in the answer nothing
	// triggers logFatalf before the check below.

	if len(msgs) != 1 {
		t.Errorf("expected one error, actual %v", len(msgs))
	}
}
I'd use the supremely handy bouk/monkey package (here along with stretchr/testify).
// TestGoodby replaces log.Fatal with a fake that checks the message and
// panics, then asserts that goodbye panics with the fake's value.
func TestGoodby(t *testing.T) {
	const wantMsg = "Goodbye!"

	fakeLogFatal := func(msg ...interface{}) {
		assert.Equal(t, wantMsg, msg[0])
		panic("log.Fatal called")
	}
	// Patch is applied immediately; only the Unpatch call is deferred.
	defer monkey.Patch(log.Fatal, fakeLogFatal).Unpatch()

	assert.PanicsWithValue(t, "log.Fatal called", goodbye, "log.Fatal was not called")
}
I advise reading the caveats to using bouk/monkey before going this route.
There used to be an answer here that I referred to, looks like it got deleted. It was the only one I've seen where you could have passing tests without modifying dependencies or otherwise touching the code that should Fatal.
I agree with other answers that this is usually an inappropriate test. Usually you should rewrite the code under test to return an error, test the error is returned as expected, and Fatal at a higher level scope after observing the non-nil error.
To OP's question of testing that the correct messages have been logged, you would inspect the inner process's cmd.Stdout.
https://play.golang.org/p/J8aiO9_NoYS
// TestFooFatals re-runs itself in a child process with FATAL_TESTING set. The
// child calls foo; the parent passes only when the child ends with a failing
// exit status.
func TestFooFatals(t *testing.T) {
	fmt.Println("TestFooFatals")

	if os.Getenv("FATAL_TESTING") != "" {
		// We're in the spawned process.
		// Do something that should fatal so this test fails.
		foo()
		return
	}

	fmt.Println("Outer process: Spawning inner `go test` process, looking for failure from fatal")
	cmd := exec.Command(os.Args[0], "-test.run=TestFooFatals")
	cmd.Env = append(os.Environ(), "FATAL_TESTING=1")
	err := cmd.Run()
	fmt.Printf("Outer process: Inner process returned %v\n", err)

	exitErr, isExit := err.(*exec.ExitError)
	if isExit && !exitErr.Success() {
		// The inner process failed as expected, so the outer test passes.
		return
	}
	t.Fatalf("Failure: inner function returned %v, want exit status 1", err)
}
// foo logs the value of FATAL_TESTING and then terminates the process through
// log.Fatal, so it should fatal every time it is called.
func foo() {
	log.Printf("oh my goodness, i see %q\n", os.Getenv("FATAL_TESTING"))
	// The subprocess test relies on this exit to observe a failing status.
	log.Fatal("oh my gosh")
}
I've combined answers from different sources to produce this:
import (
"bufio"
"bytes"
"errors"
"fmt"
"io/ioutil"
"log"
"os"
"os/exec"
"os/user"
"strings"
"testing"
"bou.ke/monkey"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
)
// TestCommandThatErrors runs CLI invocations that are expected to end in
// os.Exit. os.Exit is patched to panic, log output is captured in a buffer,
// and each case asserts both the panic and the logged error text.
func TestCommandThatErrors(t *testing.T) {
	fakeExit := func(int) {
		panic("os.Exit called")
	}
	patch := monkey.Patch(os.Exit, fakeExit)
	defer patch.Unpatch()

	var buf bytes.Buffer
	// log.Writer() is evaluated now, so the previous destination is restored
	// when the test ends instead of leaving the logger on this buffer.
	defer log.SetOutput(log.Writer())
	log.SetOutput(&buf)

	for _, tc := range []struct {
		cliArgs       []string
		expectedError string
	}{
		{
			cliArgs:       []string{"dev", "api", "--dockerless"},
			expectedError: "Some services don't have dockerless variants implemented yet.",
		},
	} {
		t.Run(strings.Join(tc.cliArgs, " "), func(t *testing.T) {
			harness := createTestApp()
			for _, cmd := range commands {
				cmd(harness.app)
			}

			assert.Panics(t, func() { harness.app.run(tc.cliArgs) })
			assert.Contains(t, buf.String(), tc.expectedError)
			// Start the next case with an empty capture buffer.
			buf.Reset()
		})
	}
}
Works great :)
You cannot and you should not.
This "you must 'test' each and every line"-attitude is strange, especially for terminal conditions and that's what log.Fatal is for.
(Or just test it from the outside.)

Go method chaining and error handling

I want to create a method chaining API in Go. In all examples I can find the chained operations seem always to succeed which I can't guarantee. I therefore try to extend these to add the error return value.
If I do it like this
package main
import "fmt"
// Chain is the (empty) receiver type for the chained operations.
type Chain struct{}

// funA prints "A" and returns the receiver with a nil error.
func (c *Chain) funA() (*Chain, error) {
	fmt.Println("A")
	return c, nil
}

// funB prints "B" and returns the receiver with a nil error.
func (c *Chain) funB() (*Chain, error) {
	fmt.Println("B")
	return c, nil
}

// funC prints "C" and returns the receiver with a nil error.
func (c *Chain) funC() (*Chain, error) {
	fmt.Println("C")
	return c, nil
}
// main attempts to chain funA, funB and funC in a single expression.
// NOTE(review): this is the question's non-compiling code. Each method
// returns two values, so the call on the marked line is rejected with
// "multiple-value c.funA() in single-value context".
func main() {
fmt.Println("Hello, playground")
c := Chain{}
d, err := c.funA().funB().funC() // line 24
}
The compiler tells me chain-err-test.go:24: multiple-value c.funA() in single-value context and won't compile. Is there a good way so funcA, funcB and funcC can report an error and stop that chain?
Is there a good way so funcA, funcB and funcC can report an error and stop that chain?
Unfortunately, no, there is no good solution to your problem. Workarounds are sufficiently complex (adding in error channels, etc) that the cost exceeds the gain.
Method chaining isn't an idiom in Go (at least not for methods that can possibly error). This isn't because there is anything particularly wrong with method chains, but a consequence of the idiom of returning errors instead of panicking. The other answers are workarounds, but none are idiomatic.
Can I ask, is it not idiomatic to chain methods in Go because of the consequence of returning error as we do in Go, or is it more generally a consequence of having multiple method returns?
Good question, but it's not because Go supports multiple returns. Python supports multiple returns, and Java can too via a Tuple<T1, T2> class; method chains are common in both languages. The reason these languages can get away with it is because they idiomatically communicate errors via exceptions. Exceptions stop the method chain immediately and jump to the relevant exception handler. This is the behavior the Go developers were specifically trying to avoid by choosing to return errors instead.
You can try like that:
https://play.golang.org/p/dVn_DGWt1p_H
package main
import (
"errors"
"fmt"
)
// Chain carries the first error raised by any step; once err is set, every
// later step does nothing and just hands the chain on.
type Chain struct {
	err error
}

// funA prints "A" unless an earlier step has already failed.
func (c *Chain) funA() *Chain {
	if c.err == nil {
		fmt.Println("A")
	}
	return c
}

// funB records the error "error at funB" and prints "B", unless an earlier
// step has already failed.
func (c *Chain) funB() *Chain {
	if c.err == nil {
		c.err = errors.New("error at funB")
		fmt.Println("B")
	}
	return c
}

// funC prints "C" unless an earlier step has already failed.
func (c *Chain) funC() *Chain {
	if c.err == nil {
		fmt.Println("C")
	}
	return c
}
// main runs the three steps as one chain and prints the error the chain ended
// up with.
func main() {
	c := &Chain{}
	fmt.Println(c.funA().funB().funC().err)
}
If you have control over the code and the function signature is identical you can write something like:
// ChainCall invokes fns in order and stops at the first one that returns a
// non-nil error, which it returns. It returns nil when every call succeeds or
// when fns is empty.
func ChainCall(fns ...func() (*Chain, error)) (err error) {
	for i := range fns {
		if _, callErr := fns[i](); callErr != nil {
			return callErr
		}
	}
	return nil
}
playground
You can make your chain lazy by collecting a slice of functions:
package main
import (
"fmt"
)
type (
	// chainFunc is one deferred step of a chain.
	chainFunc func() error

	// funcsChain is an immutable list of deferred steps. Every builder method
	// returns a new chain and leaves its receiver unchanged.
	funcsChain struct {
		funcs []chainFunc
	}
)

// Chain returns an empty chain.
func Chain() funcsChain {
	return funcsChain{}
}

// with returns a new chain consisting of the receiver's steps followed by f.
// The capacity of the source slice is clipped so append always copies; this
// keeps two chains derived from the same base from overwriting each other's
// last step in a shared backing array.
func (chain funcsChain) with(f chainFunc) funcsChain {
	n := len(chain.funcs)
	return funcsChain{append(chain.funcs[:n:n], f)}
}

// Say appends a step that prints s and succeeds.
func (chain funcsChain) Say(s string) funcsChain {
	return chain.with(func() error {
		fmt.Println(s)
		return nil
	})
}

// TryToSay appends a step that always fails; s is not used by the step.
func (chain funcsChain) TryToSay(s string) funcsChain {
	return chain.with(func() error {
		return fmt.Errorf("don't speek golish")
	})
}

// Execute runs the steps in order. It returns the index of the first failing
// step together with its error, or -1 and nil when every step succeeds.
func (chain funcsChain) Execute() (i int, err error) {
	for idx, f := range chain.funcs {
		if stepErr := f(); stepErr != nil {
			return idx, stepErr
		}
	}
	return -1, nil
}
// main builds a two-step chain, executes it and prints the index and error
// that Execute reports.
func main() {
	chain := Chain().
		Say("Hello, playground").
		TryToSay("go cannot into chains")

	i, err := chain.Execute()
	fmt.Printf("i: %d, err: %s", i, err)
}
You don't actually need channels and/or contexts to get something like this to work. I think this implementation meets all your requirements but needless to say, this leaves a sour taste. Go is not a functional language and it's best not to treat it as such.
package main
import (
"errors"
"fmt"
"strconv"
)
// Res is the value passed along a chain: a payload, a flag that asks the
// chain to stop, and the error accumulated so far.
type Res[T any] struct {
	Val  T
	Halt bool
	Err  error
}

// Chain calls args in order and returns the result of the last call that
// produced one, stopping early as soon as a result has Halt set. A step that
// returns nil is ignored instead of being dereferenced, so the previous
// result (initially the receiver) is kept. With no args the receiver is
// returned.
func (r *Res[T]) Chain(args ...func() *Res[T]) *Res[T] {
	temp := r
	for _, f := range args {
		next := f()
		if next == nil {
			// Nothing to inspect or carry forward; keep the previous result.
			continue
		}
		temp = next
		if temp.Halt {
			break
		}
	}
	return temp
}

// funA is an example step that converts any type -> string -> int -> string.
// A failed conversion is wrapped and stored in the receiver's Err; the
// returned value is a new Res[string] carrying the receiver's Err.
func (r *Res[T]) funA() *Res[string] {
	s := fmt.Sprint(r.Val)
	i, err := strconv.Atoi(s)
	if err != nil {
		r.Err = fmt.Errorf("wrapping error: %w", err)
	}
	fmt.Println("the function down the pipe is forced to work with Res[string]")
	return &Res[string]{Val: strconv.Itoa(i), Err: r.Err}
}

// funB unwraps the receiver's error, prints it, and signals a halt when the
// unwrapped error is non-nil.
func (r *Res[T]) funB() *Res[T] {
	prev := errors.Unwrap(r.Err)
	fmt.Printf("unwrapped error: %v\n", prev)
	// signal a halt if something is wrong
	if prev != nil {
		r.Halt = true
	}
	return r
}

// funC prints a marker line and returns the receiver.
func (r *Res[T]) funC() *Res[T] {
	fmt.Println("this one never gets executed...")
	return r
}

// funD prints a marker line and returns the receiver.
func (r *Res[T]) funD() *Res[T] {
	fmt.Println("...but this one does")
	return r
}

// funE is a plain function usable as a chain step; it prints a marker line
// and returns nil.
func funE() *Res[string] {
	fmt.Println("Chain can even take non-methods, but beware of nil returns")
	return nil
}
// main chains funA, funB and funC, then calls funD on the result, chains
// funE, and finally calls funC.
func main() {
	r := &Res[string]{}
	first := r.Chain(r.funA, r.funB, r.funC)
	first.funD().Chain(funE).funC() // ... and so on
}
How about this approach: Create a struct that delegates Chain and error, and return it instead of two values. e.g.:
package main
import "fmt"
// Chain is the (empty) receiver type for the chained operations.
type Chain struct{}

// ChainAndError bundles the chain with an error in one value. Both fields are
// embedded, so the methods of *Chain can be called on it directly.
type ChainAndError struct {
	*Chain
	error
}

// funA prints "A" and returns the receiver paired with a nil error.
func (c *Chain) funA() ChainAndError {
	fmt.Println("A")
	return ChainAndError{Chain: c}
}

// funB prints "B" and returns the receiver paired with a nil error.
func (c *Chain) funB() ChainAndError {
	fmt.Println("B")
	return ChainAndError{Chain: c}
}

// funC prints "C" and returns the receiver paired with a nil error.
func (c *Chain) funC() ChainAndError {
	fmt.Println("C")
	return ChainAndError{Chain: c}
}
// main chains the three steps in one expression and prints the error carried
// by the final result.
func main() {
	fmt.Println("Hello, playground")
	c := &Chain{}
	result := c.funA().funB().funC() // line 24
	fmt.Println(result.error)
}

How to test os.exit scenarios in Go

Given this code
// doomed terminates the process with exit status 1.
func doomed() {
os.Exit(1)
}
How do I properly test that calling this function will result in an exit using go test? This needs to occur within a suite of tests, in other words the os.Exit() call cannot impact the other tests and should be trapped.
There's a presentation by Andrew Gerrand (one of the core members of the Go team) where he shows how to do it.
Given a function (in main.go)
package main
import (
"fmt"
"os"
)
// Crasher prints a farewell message and then terminates the process with
// exit status 1.
func Crasher() {
fmt.Println("Going down in flames!")
os.Exit(1)
}
here's how you would test it (through main_test.go):
package main
import (
"os"
"os/exec"
"testing"
)
// TestCrasher verifies that Crasher exits with a failing status. The test
// binary re-executes itself restricted to this test with BE_CRASHER=1; the
// child calls Crasher, and the parent inspects how the child ended.
func TestCrasher(t *testing.T) {
	if os.Getenv("BE_CRASHER") == "1" {
		// Child process: run the function under test.
		Crasher()
		return
	}

	child := exec.Command(os.Args[0], "-test.run=TestCrasher")
	child.Env = append(os.Environ(), "BE_CRASHER=1")
	runErr := child.Run()

	// Only an ExitError that reports failure counts as the expected outcome.
	exitErr, isExit := runErr.(*exec.ExitError)
	if !isExit || exitErr.Success() {
		t.Fatalf("process ran with err %v, want exit status 1", runErr)
	}
}
What the code does is invoke go test again in a separate process through exec.Command, limiting execution to the TestCrasher test (via the -test.run=TestCrasher switch). It also passes in a flag via an environment variable (BE_CRASHER=1) which the second invocation checks for and, if set, calls the system-under-test, returning immediately afterwards to prevent running into an infinite loop. Thus, we are being dropped back into our original call site and may now validate the actual exit code.
Source: Slide 23 of Andrew's presentation. The second slide contains a link to the presentation's video as well.
He talks about subprocess tests at 47:09
I do this by using bouk/monkey:
// TestDoomed replaces os.Exit with a fake that panics and asserts that doomed
// panics with the fake's value.
func TestDoomed(t *testing.T) {
	fakeExit := func(int) {
		panic("os.Exit called")
	}
	// Patch is applied immediately; only the Unpatch call is deferred.
	defer monkey.Patch(os.Exit, fakeExit).Unpatch()

	assert.PanicsWithValue(t, "os.Exit called", doomed, "os.Exit was not called")
}
monkey is super-powerful when it comes to this sort of work, and for fault injection and other difficult tasks. It does come with some caveats.
I don't think you can test the actual os.Exit without simulating testing from the outside (using exec.Command) process.
That said, you might be able to accomplish your goal by creating an interface or function type and then use a noop implementation in your tests:
Go Playground
package main
import "os"
import "fmt"
// exiter is the signature of a function that ends the program with a status
// code; tests can supply one that does nothing.
type exiter func(code int)

// main first calls doExit with a no-op exiter, prints a marker to show that
// execution continued, and then calls doExit with the real os.Exit.
func main() {
	noop := func(int) {}
	doExit(noop)
	fmt.Println("got here")
	doExit(os.Exit)
}

// doExit invokes the supplied exiter with status code 1.
func doExit(exit exiter) {
	const status = 1
	exit(status)
}
You can't, you would have to use exec.Command and test the returned value.
Code for testing:
package main
import "os"
// my_private_exit_function is the exit function used by this package. It
// defaults to os.Exit and can be replaced in tests.
var my_private_exit_function = os.Exit

// main exits with status 1 through MyAbstractFunctionAndExit.
func main() {
	const status = 1
	MyAbstractFunctionAndExit(status)
}

// MyAbstractFunctionAndExit passes exit to the currently installed exit
// function.
func MyAbstractFunctionAndExit(exit int) {
	my_private_exit_function(exit)
}
Testing code:
package main
import (
"os"
"testing"
)
// TestMyAbstractFunctionAndExit installs a fake exit function, calls
// MyAbstractFunctionAndExit and checks that the fake was invoked.
func TestMyAbstractFunctionAndExit(t *testing.T) {
	// Restore the real exit function when the test ends, even if the call
	// below panics.
	defer func() { my_private_exit_function = os.Exit }()

	// Prepare testing
	called := false
	my_private_exit_function = func(c int) {
		called = true
	}

	// Run function
	MyAbstractFunctionAndExit(1)

	// Check
	if !called {
		t.Errorf("Error in AbstractFunction()")
	}
}
To test the os.Exit like scenarios we can use the https://github.com/undefinedlabs/go-mpatch along with the below code. This ensures that your code remains clean as well as readable and maintainable.
// PatchedOSExit records how a patched os.Exit was used: Called reports
// whether it was invoked and CalledWith holds the status code it received.
type PatchedOSExit struct {
	Called     bool
	CalledWith int
	patchFunc  *mpatch.Patch
}

// PatchOSExit replaces os.Exit with a function that records the call in the
// returned PatchedOSExit and then runs mockOSExitImpl with the same code.
//
// If the patch cannot be installed the test is aborted with t.Fatalf, because
// continuing would let the code under test reach the real os.Exit.
func PatchOSExit(t *testing.T, mockOSExitImpl func(int)) *PatchedOSExit {
	t.Helper()

	patchedExit := &PatchedOSExit{Called: false}
	patchFunc, err := mpatch.PatchMethod(os.Exit, func(code int) {
		patchedExit.Called = true
		patchedExit.CalledWith = code
		mockOSExitImpl(code)
	})
	if err != nil {
		t.Fatalf("Failed to patch os.Exit due to an error: %v", err)
		return nil
	}

	patchedExit.patchFunc = patchFunc
	return patchedExit
}

// Unpatch removes the os.Exit replacement. It is a no-op on a nil receiver or
// when no patch was installed, so it is always safe to defer.
func (p *PatchedOSExit) Unpatch() {
	if p == nil || p.patchFunc == nil {
		return
	}
	// Unpatch has no *testing.T to report through; the error is dropped on purpose.
	_ = p.patchFunc.Unpatch()
}
You can consume the above code as follows:
// NewSampleApplication is the example code under test: it exits the process
// with status 101.
func NewSampleApplication() {
os.Exit(101)
}
// Test_NewSampleApplication_OSExit patches os.Exit with a no-op, runs the
// application code and checks that an exit was requested with status 101.
func Test_NewSampleApplication_OSExit(t *testing.T) {
	// Prepare mock setup
	p := PatchOSExit(t, func(int) {})
	defer p.Unpatch()

	// Call the application code
	NewSampleApplication()

	// Assert that os.Exit gets called
	if !p.Called {
		t.Errorf("Expected os.Exit to be called but it was not called")
		return
	}

	// Also, Assert that os.Exit gets called with the correct code
	const expectedCalledWith = 101
	if p.CalledWith != expectedCalledWith {
		t.Errorf("Expected os.Exit to be called with %d but it was called with %d", expectedCalledWith, p.CalledWith)
		return
	}
}
I've also added a link to Playground: https://go.dev/play/p/FA0dcwVDOm7