Issue with go-sqlmock testing in the expected query part - go-gorm

I am using go-sqlmock for the first time and I am trying to write a test for a post operation. I am using gorm and gin.
1. The test is giving me an error at s.mock.ExpectQuery(regexp.QuoteMeta(.... I am not sure what the issue is here. I have posted both the test and the output.
2. Also (this has nothing to do with 1), in this test I really do not know what the code will be, as it is randomly generated in the API controller. Is there a way to accept a generic value in the code field?
The test file
package unit
import (
"net/http"
"net/http/httptest"
"regexp"
"testing"
"github.com/DATA-DOG/go-sqlmock"
"github.com/SamiAlsubhi/go/controllers"
"github.com/SamiAlsubhi/go/routes"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
"gorm.io/driver/postgres"
"gorm.io/gorm"
)
// Suite is the testify suite used by these tests. It carries the gorm handle,
// the sqlmock expectation recorder and the gin router under test.
type Suite struct {
suite.Suite
DB *gorm.DB
mock sqlmock.Sqlmock
router *gin.Engine
}
// SetupSuite prepares the shared fixtures: a sqlmock connection, a gorm handle
// opened over that connection through the postgres dialector, and the router
// built around an API that uses the handle.
// NOTE(review): this method takes a *testing.T parameter; testify's setup hook
// is presumably SetupSuite() with no parameters — confirm that this method is
// actually invoked by suite.Run.
func (s *Suite) SetupSuite(t *testing.T) {
conn, mock, err := sqlmock.New()
if err != nil || conn == nil {
t.Errorf("Failed to open mock sql db, got error: %v", err)
}
s.mock = mock
// Hand the mocked connection to gorm instead of letting it dial the DSN.
dialector := postgres.New(postgres.Config{
DSN: "sqlmock_db_0",
DriverName: "postgres",
Conn: conn,
PreferSimpleProtocol: true,
})
if db, err := gorm.Open(dialector, &gorm.Config{}); err != nil || db == nil {
t.Errorf("Failed to open gorm v2 db, got error: %v", err)
} else {
s.DB = db
}
api := &controllers.API{Db: s.DB}
s.router = routes.SetupRouter(api)
}
// TestSetup is the go test entry point that runs every test method of Suite.
func TestSetup(t *testing.T) {
	suite.Run(t, &Suite{})
}
// AfterTest fails the current test if any sqlmock expectation registered on
// s.mock was not met.
func (s *Suite) AfterTest(_, _ string) {
	unmet := s.mock.ExpectationsWereMet()
	require.NoError(s.T(), unmet)
}
// Test_GetOTP expects a single INSERT into "otps" with the given phone and
// code, followed by a commit, then calls GET /api/auth/get-otp/<phone> on the
// router and asserts an HTTP 200 response.
func (s *Suite) Test_GetOTP() {
var (
phone = "99999999"
code = "123456"
)
// QuoteMeta escapes the SQL text so it is matched literally.
s.mock.ExpectQuery(regexp.QuoteMeta(
`INSERT INTO "otps" ("phone","code") VALUES ($1,$2) RETURNING "otps"."id"`)).
WithArgs(phone, code).
WillReturnRows(sqlmock.NewRows([]string{"id"}).
AddRow(1))
s.mock.ExpectCommit()
w := httptest.NewRecorder()
req, err := http.NewRequest("GET", "/api/auth/get-otp/"+phone, nil)
require.NoError(s.T(), err)
s.router.ServeHTTP(w, req)
assert.Equal(s.T(), 200, w.Code)
//require.Nil(s.T(), deep.Equal(&model.Person{ID: id, Name: name}, w.Body))
}
the output.
--- FAIL: TestSetup (0.00s)
--- FAIL: TestSetup/Test_GetOTP (0.00s)
/Users/sami/Desktop/SamiAlsubhi/go/test/unit/suite.go:63: test panicked: runtime error: invalid memory address or nil pointer dereference
goroutine 26 [running]:
runtime/debug.Stack()
/usr/local/go/src/runtime/debug/stack.go:24 +0x65
github.com/stretchr/testify/suite.failOnPanic(0xc000001a00)
/Users/sami/Desktop/golang/pkg/mod/github.com/stretchr/testify@v1.7.0/suite/suite.go:63 +0x3e
panic({0x49e96a0, 0x5193810})
/usr/local/go/src/runtime/panic.go:1038 +0x215
github.com/SamiAlsubhi/go/test/unit.(*Suite).AfterTest(0x4abe61b, {0x4becfd0, 0xc000468940}, {0x0, 0x0})
/Users/sami/Desktop/SamiAlsubhi/go/test/unit/setup_test.go:60 +0x1c
github.com/stretchr/testify/suite.Run.func1.1()
/Users/sami/Desktop/golang/pkg/mod/github.com/stretchr/testify@v1.7.0/suite/suite.go:137 +0x1b7
panic({0x49e96a0, 0x5193810})
/usr/local/go/src/runtime/panic.go:1038 +0x215
github.com/SamiAlsubhi/go/test/unit.(*Suite).Test_GetOTP(0xc000468940)
/Users/sami/Desktop/SamiAlsubhi/go/test/unit/setup_test.go:69 +0x4f
reflect.Value.call({0xc000049140, 0xc000010308, 0x13}, {0x4abf50c, 0x4}, {0xc000080e70, 0x1, 0x1})
/usr/local/go/src/reflect/value.go:543 +0x814
reflect.Value.Call({0xc000049140, 0xc000010308, 0xc000468940}, {0xc0003c9e70, 0x1, 0x1})
/usr/local/go/src/reflect/value.go:339 +0xc5
github.com/stretchr/testify/suite.Run.func1(0xc000001a00)
/Users/sami/Desktop/golang/pkg/mod/github.com/stretchr/testify@v1.7.0/suite/suite.go:158 +0x4b6
testing.tRunner(0xc000001a00, 0xc000162000)
/usr/local/go/src/testing/testing.go:1259 +0x102
created by testing.(*T).Run
/usr/local/go/src/testing/testing.go:1306 +0x35a
FAIL
coverage: [no statements]
FAIL github.com/SamiAlsubhi/go/test/unit 0.912s
FAIL

Solution to the first issue:
When using testify/suite, there are a number of methods that, if defined on the Suite struct, are executed automatically when the test runs. However, these methods are matched against interfaces. In the case of SetupSuite, it must take no arguments and return nothing in order to run.
Solution to the second issue:
There is a way in go-sqlmock to match any kind of data by using sqlmock.AnyArg().
Fixed code:
package unit
import (
"encoding/json"
"fmt"
"math/rand"
"net/http"
"net/http/httptest"
"regexp"
"testing"
"github.com/DATA-DOG/go-sqlmock"
"github.com/SamiAlsubhi/go/controllers"
"github.com/SamiAlsubhi/go/routes"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
"gorm.io/driver/postgres"
"gorm.io/gorm"
)
// Suite is the testify suite used by these tests. It carries the gorm handle,
// the sqlmock expectation recorder and the gin router under test.
type Suite struct {
suite.Suite
DB *gorm.DB
mock sqlmock.Sqlmock
router *gin.Engine
}
// SetupSuite builds the fixtures shared by the suite: a sqlmock connection
// created with the regexp query matcher, a gorm handle opened over that
// connection (default transactions disabled), and the router built around an
// API flagged as running under test. Any setup failure panics.
func (s *Suite) SetupSuite() {
	//t.Logf("setup start")
	sqlConn, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherRegexp))
	if sqlConn == nil || err != nil {
		panic(fmt.Sprintf("Failed to open mock sql db, got error: %v", err))
	}
	s.mock = mock

	// Hand the mocked connection to gorm instead of letting it dial the DSN.
	pgConfig := postgres.Config{
		DSN:                  "sqlmock_db_0",
		DriverName:           "postgres",
		Conn:                 sqlConn,
		PreferSimpleProtocol: true,
	}
	gormDB, err := gorm.Open(postgres.New(pgConfig), &gorm.Config{SkipDefaultTransaction: true})
	if gormDB == nil || err != nil {
		panic(fmt.Sprintf("Failed to open gorm v2 db, got error: %v", err))
	}
	s.DB = gormDB

	s.router = routes.SetupRouter(&controllers.API{Db: s.DB, IsTesting: true})
}
// TestSetup is the go test entry point that runs every test method of Suite.
func TestSetup(t *testing.T) {
	suite.Run(t, &Suite{})
}
// func (s *Suite) AfterTest(_, _ string) {
// require.NoError(s.T(), s.mock.ExpectationsWereMet())
// }
// Test_GetOTP_Non_Existing_Phone requests an OTP for a phone number that has
// no row in the otps table. It expects a count query returning 0 and an INSERT
// whose timestamp and code arguments may be anything, then checks that the
// endpoint answers 200 with an "expiry_in" field in the JSON body.
func (s *Suite) Test_GetOTP_Non_Existing_Phone() {
	phone := fmt.Sprintf("%v", 90000000+rand.Intn(99999999-90000000))

	// The two statements may be issued in either order.
	s.mock.MatchExpectationsInOrder(false)

	countSQL := regexp.QuoteMeta(`SELECT count(*) FROM "otps" WHERE phone = $1 AND "otps"."deleted_at" IS NULL`)
	countRows := sqlmock.NewRows([]string{"count"}).AddRow(0)
	s.mock.ExpectQuery(countSQL).WithArgs(phone).WillReturnRows(countRows)

	insertSQL := regexp.QuoteMeta(
		`INSERT INTO "otps" ("created_at","updated_at","deleted_at","phone","code") VALUES ($1,$2,$3,$4,$5) RETURNING "id"`)
	insertRows := sqlmock.NewRows([]string{"id"}).AddRow(1)
	// Timestamps and the generated code are not known in advance.
	s.mock.ExpectQuery(insertSQL).
		WithArgs(sqlmock.AnyArg(), sqlmock.AnyArg(), sqlmock.AnyArg(), phone, sqlmock.AnyArg()).
		WillReturnRows(insertRows)

	recorder := httptest.NewRecorder()
	request, err := http.NewRequest("GET", "/api/auth/get-otp/"+phone, nil)
	require.NoError(s.T(), err)
	s.router.ServeHTTP(recorder, request)
	assert.Equal(s.T(), 200, recorder.Code)

	//parse response
	var payload gin.H
	err = json.Unmarshal(recorder.Body.Bytes(), &payload)
	require.NoError(s.T(), err)
	_, hasExpiry := payload["expiry_in"]
	assert.True(s.T(), hasExpiry)

	require.NoError(s.T(), s.mock.ExpectationsWereMet())
}

Related

How to optimize database connections

In my Go application I use the crontab package to run the Tracker function every minute. As you can see from the code, I call a PostgreSQL function. To interact with the PostgreSQL database, I use the gorm package. The application worked for several days without any problem, but now I notice an error in the logs: pq: sorry, too many clients already. I know that similar questions have been asked several times on StackOverflow before. For example, in this post people advise using the Exec or Scan methods. In my case, as you can see, I use the Exec method, but I still get the error. As far as I understand, each database request makes a separate connection and does not close it. I can't figure out what I'm doing wrong.
main.go:
package main
import (
"github.com/mileusna/crontab"
)
// main connects to PostgreSQL, schedules controllers.Tracker with the cron
// expression "* * * * *", and disconnects from the database when it returns.
// NOTE(review): nothing blocks after the job is scheduled, so main returns
// right away; presumably the full program keeps the process alive — confirm.
func main() {
	database.ConnectPostgreSQL()
	defer database.DisconnectPostgreSQL()

	// err is declared in the if statement; a bare assignment to an undeclared
	// err does not compile.
	if err := crontab.New().AddJob("* * * * *", controllers.Tracker); err != nil {
		utils.Logger().Fatal(err)
		return
	}
}
tracker.go:
package controllers
import (
"questionnaire/database"
"time"
)
// Tracker is the scheduled job body. It calls the tracker procedure with the
// current time formatted as "2006-01-02 15:04:05" and logs any error.
var Tracker = func() {
	now := time.Now().Format("2006-01-02 15:04:05")
	if err := database.DBGORM.Exec("CALL tracker($1)", now).Error; err != nil {
		utils.Logger().Println(err) // ERROR: pq: sorry, too many clients already
		return
	}
}
PostgreSQL.go:
package database
import (
"fmt"
"github.com/jinzhu/gorm"
_ "github.com/jinzhu/gorm/dialects/postgres"
"github.com/joho/godotenv"
"questionnaire/utils"
)
// DBGORM is the package-wide gorm handle; it is assigned by ConnectPostgreSQL
// and closed by DisconnectPostgreSQL.
var DBGORM *gorm.DB
func ConnectPostgreSQL() {
err := godotenv.Load(".env")
if err != nil {
utils.Logger().Println(err)
panic(err)
}
databaseUser := utils.CheckEnvironmentVariable("PostgreSQL_USER")
databasePassword := utils.CheckEnvironmentVariable("PostgreSQL_PASSWORD")
databaseHost := utils.CheckEnvironmentVariable("PostgreSQL_HOST")
databaseName := utils.CheckEnvironmentVariable("PostgreSQL_DATABASE_NAME")
databaseURL:= fmt.Sprintf("host=%s user=%s dbname=%s password=%s sslmode=disable", databaseHost, databaseUser, databaseName, databasePassword)
DBGORM, err = gorm.Open("postgres", databaseURL)
if err != nil {
utils.Logger().Println(err)
panic(err)
}
err = DBGORM.DB().Ping()
if err != nil {
utils.Logger().Println(err)
panic(err)
}
DBGORM.LogMode(true)
}
// DisconnectPostgreSQL closes the shared DBGORM handle and returns the error
// reported by Close, if any.
func DisconnectPostgreSQL() error {
return DBGORM.Close()
}

What's the best way to delete all database records used for a Test suite?

I have a test suite that pollutes my database using a seed read from a YAML file.
I'm wondering whether there is a way to clean my database (delete all records used for the test suite) after running my tests.
// prepareMySQLDB opens a handle to the test MySQL database and returns it
// together with its Close method, so the caller can defer the cleanup. The
// test is aborted if the handle cannot be created.
func prepareMySQLDB(t *testing.T) (db *sql.DB, closer func() error) {
	// DSN form is user:password@/dbname; "@" separates the credentials from
	// the (here empty) address part.
	db, err := sql.Open("mysql", "user:pass@/database")
	if err != nil {
		t.Fatalf("open mysql connection: %s", err)
	}
	return db, db.Close
}
// polluteDb seeds the database from seed.yml through a MySQL polluter. The
// test is aborted if the seed file cannot be opened or the seeding fails.
func polluteDb(db *sql.DB, t *testing.T) {
	seedFile, err := os.Open("seed.yml")
	if err != nil {
		t.Fatalf("failed to open seed file: %s", err)
	}
	defer seedFile.Close()

	engine := polluter.MySQLEngine(db)
	err = polluter.New(engine).Pollute(seedFile)
	if err != nil {
		t.Fatalf("failed to pollute: %s", err)
	}
}
// TestAllUsers seeds the database and checks that AllUsersD returns exactly
// one user whose e-mail matches the seeded value.
func TestAllUsers(t *testing.T) {
	t.Parallel()
	db, closeDb := prepareMySQLDB(t)
	defer closeDb()
	polluteDb(db, t)

	users, err := AllUsersD(db)
	if err != nil {
		t.Fatalf("AllUsers() failed: %v", err)
	}

	// Check the count first: indexing users[0] on an empty result would panic
	// instead of reporting a test failure.
	if got := len(users); got != 1 {
		t.Fatalf("len(AllUsers()) = %d; want 1", got)
	}
	if got := users[0].Email; got != "myemail@gmail.com" {
		t.Errorf("AllUsers().Email = %s; want myemail@gmail.com", got)
	}
}
// Test I'm interested in
// TestAddUser seeds the database and then adds a user; the rows it writes are
// not removed afterwards.
func TestAddUser(t *testing.T) {
t.Parallel()
db, closeDb := prepareMySQLDB(t)
defer closeDb()
polluteDb(db, t)
// The arguments are elided in this snippet.
user, err := AddUser(...)
if err != nil {
t.Fatal("AddUser() failed")
}
//how can I clean my database after this?
}
Should I retrieve the last ID inserted in TestAddUser() and just delete that row manually, or is there another way to save my database state and restore it afterwards?
As I said, I'm new to Go, so any other comments on my code are strongly appreciated.
The best way is usually to use a transaction, then ROLLBACK, so they are never committed in the first place.
The github.com/DATA-DOG/go-txdb package can help a lot with that.
Final code:
import (
"database/sql"
"os"
"testing"
txdb "github.com/DATA-DOG/go-txdb"
"github.com/romanyx/polluter"
)
//mostly sql tests
// init registers a driver named "txdb" on top of the "mysql" driver.
// NOTE(review): per the surrounding discussion, connections opened through it
// run inside a transaction that is rolled back, so tests leave no rows behind.
func init() {
	// DSN form is user:password@/dbname.
	txdb.Register("txdb", "mysql", "root:root@/betell_rest")
}
// TestAddUser opens a txdb-backed connection, adds one user and checks that
// the number of users returned by AllUsers grew by exactly one.
// NOTE(review): assumes the second result of AllUsers is an error — confirm.
func TestAddUser(t *testing.T) {
	db, err := sql.Open("txdb", "root:root@/betell_rest")
	if err != nil {
		t.Fatal(err)
	}
	defer db.Close()

	// A failed listing would make the count comparison below meaningless, so
	// listing errors abort the test instead of being ignored.
	users, err := AllUsers(db)
	if err != nil {
		t.Fatalf("AllUsers() before insert failed: %v", err)
	}
	before := len(users)

	if err := AddUser(db, "bla@gmail.com", "pass"); err != nil {
		t.Fatalf("AddUser() failed: %v", err)
	}

	users, err = AllUsers(db)
	if err != nil {
		t.Fatalf("AllUsers() after insert failed: %v", err)
	}
	if len(users) != before+1 {
		t.Fatal("AddUser() failed to write in database")
	}
}
Note: Also you can pass db into your polluter so you don't affect your database at all.

Mocking functions in Golang to test my http routes

I'm totally confused figuring out how I can mock a function, without using any additional packages like golang/mock. I'm just trying to learn how to do so but can't find many decent online resources.
Essentially, I followed this excellent article that explains how to use an interface to mock things.
So I've rewritten the function I wanted to test. The function just inserts some data into datastore. My tests for that are OK - I can mock the function directly.
The issue I'm having is mocking it 'within' an HTTP route I'm trying to test. I am using the Gin framework.
My router (simplified) looks like this:
// SetupRouter builds the gin engine and registers the v1 routes.
func SetupRouter() *gin.Engine {
	// gin.Default() already attaches the Logger and Recovery middleware, so
	// they are not registered a second time here.
	r := gin.Default()

	v1 := r.Group("v1")
	v1.PATCH("operations/:id", controllers.UpdateOperation)

	// The engine must be returned; the function is declared to return *gin.Engine.
	return r
}
Which calls the UpdateOperation function:
// UpdateOperation completes the operation named by the "id" path parameter
// and answers 200 with a confirmation message on success. When
// FindAndCompleteOperation returns an error, no response body is written.
func UpdateOperation(c *gin.Context) {
	var (
		resp   m.Response
		infoer m.OperationInfoer
	)
	opID := c.Param("id")

	if err := m.FindAndCompleteOperation(infoer, opID, resp.Report); err != nil {
		return
	}
	c.JSON(200, gin.H{
		"message": "Operation completed",
	})
}
So, I need to mock the FindAndCompleteOperation() function.
The main (simplified) functions looks like this:
// FindAndCompleteOp looks up the Operation whose Unique_Id equals id, marks it
// complete and stores it back in datastore.
//
// It returns nil on success, the datastore error if the query or the write
// fails, and a "Not found" error when no operation matches. (Returning
// "Not found" unconditionally would make every call look like a failure.)
func (oi OperationInfoer) FindAndCompleteOp(id string, report Report) error {
	ctx := context.Background()
	q := datastore.NewQuery("Operation").
		Filter("Unique_Id =", id).
		Limit(1)

	var ops []Operation
	ts, err := db.Datastore.GetAll(ctx, q, &ops)
	if err != nil {
		return err
	}
	if len(ops) == 0 {
		return errors.New("Not found")
	}

	ops[0].Id = ts[0].ID()
	ops[0].Complete = true
	// Do stuff
	if _, err := db.Datastore.Put(ctx, key, &o); err != nil {
		return err
	}
	log.Print("OPERATION COMPLETED")
	return nil
}
// FindAndCompleteOperation delegates to ri.FindAndCompleteOp with the same id
// and report and returns its error unchanged.
func FindAndCompleteOperation(ri OperationInfoer, id string, report Report) error {
return ri.FindAndCompleteOp(id, report)
}
// OperationInfoer is the field-less receiver type for FindAndCompleteOp.
type OperationInfoer struct{}
To test the route that updates the operation, I have something like so:
// Route spec: PATCH /v1/operations/123 with a JSON-encoded report and an
// access_token query parameter should answer 200 with the message
// "Operation completed".
FIt("Return 200, updates operation", func() {
testRouter := SetupRouter()
param := make(url.Values)
param["access_token"] = []string{public_token}
report := m.Report{}
report.Success = true
report.Output = "my output"
jsonStr, _ := json.Marshal(report)
req, _ := http.NewRequest("PATCH", "/v1/operations/123?"+param.Encode(), bytes.NewBuffer(jsonStr))
resp := httptest.NewRecorder()
// Serve the request in-process; no network listener is involved.
testRouter.ServeHTTP(resp, req)
Expect(resp.Code).To(Equal(200))
o := FakeResponse{}
json.NewDecoder(resp.Body).Decode(&o)
Expect(o.Message).To(Equal("Operation completed"))
})
Originally, I tried to cheat a bit and just tried something like this:
// Attempted shortcut: replace the package-level function with a stub that
// always reports success.
m.FindAndCompleteOperation = func(string, m.Report) error {
return nil
}
But that affects all the other tests etc.
I'm hoping someone can explain simply what the best way to mock the FindAndCompleteOperation function so I can test the routes, without relying on datastore etc.
I have another relevant, more informative answer to a similar question here, but here's an answer for your specific scenario:
Update your SetupRouter() function to take a function that can either be the real FindAndCompleteOperation function or a stub function:
Playground
package main
import "github.com/gin-gonic/gin"
// Report stands in for m.Response.Report; its fields are elided in this example.
type Report struct {
// ...
}
// OperationInfoer stands in for m.OperationInfoer; its fields are elided in
// this example.
type OperationInfoer struct {
// ...
}
// findAndComplete is the function signature SetupRouter accepts, so that
// either the real FindAndCompleteOperation or a test stub can be supplied.
type findAndComplete func(s OperationInfoer, id string, report Report) error
// FindAndCompleteOperation is the production implementation of
// findAndComplete; its body is elided in this example and returns nil.
func FindAndCompleteOperation(OperationInfoer, string, Report) error {
// ...
return nil
}
// SetupRouter returns a gin engine whose PATCH /v1/:id route calls f with the
// path id. When f returns nil the route answers 200 with a confirmation
// message; otherwise it writes nothing.
func SetupRouter(f findAndComplete) *gin.Engine {
	patchHandler := func(c *gin.Context) {
		err := f(OperationInfoer{}, c.Param("id"), Report{})
		if err != nil {
			return
		}
		c.JSON(200, gin.H{"message": "Operation completed"})
	}

	engine := gin.Default()
	engine.Group("v1").PATCH("/:id", patchHandler)
	return engine
}
// main serves the router, wired to the real FindAndCompleteOperation, on
// :8080 and panics if the server stops with an error.
func main() {
	err := SetupRouter(FindAndCompleteOperation).Run(":8080")
	if err != nil {
		panic(err)
	}
}
Test/mocking example
package main
import (
"encoding/json"
"net/http/httptest"
"strings"
"testing"
)
// TestUpdateRoute drives PATCH /v1/id_value through a router wired to a
// recording stub. It verifies that the stub is called exactly once with the
// expected arguments and that the route answers 200 with the confirmation
// message.
func TestUpdateRoute(t *testing.T) {
	// build findAndComplete stub
	var callCount int
	var lastInfoer OperationInfoer
	var lastID string
	var lastReport Report
	stub := func(s OperationInfoer, id string, report Report) error {
		callCount++
		lastInfoer = s
		lastID = id
		lastReport = report
		return nil // or `fmt.Errorf("Err msg")` if you want to test fault path
	}

	// invoke endpoint
	w := httptest.NewRecorder()
	r := httptest.NewRequest(
		"PATCH",
		"/v1/id_value",
		strings.NewReader(""),
	)
	SetupRouter(stub).ServeHTTP(w, r)

	// check that the stub was invoked correctly
	if callCount != 1 {
		t.Fatal("Wanted 1 call; got", callCount)
	}
	if lastInfoer != (OperationInfoer{}) {
		t.Fatalf("Wanted %v; got %v", OperationInfoer{}, lastInfoer)
	}
	if lastID != "id_value" {
		t.Fatalf("Wanted 'id_value'; got '%s'", lastID)
	}
	if lastReport != (Report{}) {
		t.Fatalf("Wanted %v; got %v", Report{}, lastReport)
	}

	// check that the correct response was returned
	if w.Code != 200 {
		t.Fatal("Wanted HTTP 200; got HTTP", w.Code)
	}
	var body map[string]string
	if err := json.Unmarshal(w.Body.Bytes(), &body); err != nil {
		t.Fatal("Unexpected error:", err)
	}
	if body["message"] != "Operation completed" {
		// Fatalf, not Fatal: the message contains a formatting verb.
		t.Fatalf("Wanted 'Operation completed'; got %s", body["message"])
	}
}
You can't mock if you use globals that can't be mocked in a handler. Either your globals are mockable (i.e. declared as variables of interface type) or you need to use dependency injection.
func (oi OperationInfoer) FindAndCompleteOp(id string, report Report) error {...}
looks like a method of a struct, so you should be able to inject that struct into a handler, at the very least.
// OperationInfoer is the behavior a handler depends on; a test can supply its
// own implementation in place of the concrete one.
type OperationInfoer interface {
FindAndCompleteOp(id string, report Report) error
}
// ConcreteOperationInfoer is the placeholder for the production implementation.
type ConcreteOperationInfoer struct { /* actual implementation */ }
// UpdateOperation returns a gin handler closed over oi, so the dependency is
// injected when the route is registered. The handler body is elided here.
func UpdateOperation(oi OperationInfoer) func(c *gin.Context) {
return func (c *gin.Context){
// the code
}
}
then mocking becomes a breeze in your tests :
UpdateOperation(mockOperationInfoer)(ginContext)
You can use a struct instead of closures
// UpdateOperationHandler carries the injected OperationInfoer for its
// handler method.
type UpdateOperationHandler struct {
Oi OperationInfoer
}
// UpdateOperation is the gin handler; it calls FindAndCompleteOp on the
// injected h.Oi. The call arguments are elided in this sketch.
func (h UpdateOperationHandler ) UpdateOperation (c *gin.Context) {
h.Oi.FindAndCompleteOp(/* code */ )
}

Golang test with channels does not exit

The following Golang test never exits. I suspect it has something to do with a channel deadlock but being a go-noob, I am not very certain.
const userName = "xxxxxxxxxxxx"
// TestSynchroninze creates a channel for github.ChannelName, runs Synchronize
// with the user id stored in the context, and expects every output item to
// carry an api.User.
func TestSynchroninze(t *testing.T) {
c, err := channel.New(github.ChannelName, authToken)
if err != nil {
t.Fatalf("Could not create channel: %s", err)
return
}
state := channel.NewState(nil)
ctx := context.Background()
ctx = context.WithValue(ctx, "userId", userName)
user := api.User{}
output, errs := c.Synchronize(state, ctx)
// The error channel is received from before any output is consumed.
if err = <-errs; err != nil {
t.Fatalf("Error performing synchronize: %s", err)
return
}
// This loop ends only when the output channel is closed.
for o := range output {
switch oo := o.Data.(type) {
case api.User:
user = oo
glog.Infof("we have a USER %s\n", user)
default:
t.Errorf("Encountered unexpected data type: %T", oo)
}
}
}
Here are the methods being tested.
// github is the implementation backed by an api.Client.
type github struct {
client *api.Client
}
// newImplementation returns a github implementation whose API client is
// created from the given user token.
func newImplementation(t auth.UserToken) implementation.Implementation {
return &github{client: api.NewClient(t)}
}
// -------------------------------------------------------------------------------------
const (
// kLastUserFetch is the state key under which fetchUser records the time of
// the last user fetch.
kLastUserFetch = "lastUserFetch"
)
// synchronizeFunc is the signature shared by the fetch steps, such as fetchUser.
type synchronizeFunc func(implementation.MutableState, chan *implementation.Output, context.Context) error
// -------------------------------------------------------------------------------------
// Synchronize fetches the user and returns the output and error channels.
// In this version fetchUser is called synchronously, before the channels are
// returned to the caller, and output is unbuffered.
func (g *github) Synchronize(state implementation.MutableState, ctx context.Context) (<-chan *implementation.Output, <-chan error) {
output := make(chan *implementation.Output)
errors := make(chan error, 1) // buffer allows preflight errors
// Close output channels once we're done
// The deferred function runs when Synchronize returns and starts a goroutine
// that closes both channels.
defer func() {
go func() {
// wg.Wait()
close(errors)
close(output)
}()
}()
err := g.fetchUser(state, output, ctx)
if err != nil {
errors <- err
}
return output, errors
}
// fetchUser reads the "userId" string from ctx, fetches that user through the
// API client and, on success, sends the user on output and records the fetch
// time in state under kLastUserFetch. It returns the error from GetUser.
func (g *github) fetchUser(state implementation.MutableState, output chan *implementation.Output, ctx context.Context) error {
	// A missing or non-string "userId" yields the empty string.
	userId, _ := ctx.Value("userId").(string)

	fetched, err := g.client.GetUser(userId, ctx.Done())
	if err != nil {
		return err
	}

	glog.Info("No error in fetchUser")
	output <- &implementation.Output{Data: fetched}
	state.SetTime(kLastUserFetch, time.Now())
	return nil
}
// GetUser requests "users/<id>" and decodes the JSON response into a User.
//
// quit is passed through to the underlying request. The raw response is
// logged before the error is inspected. The request error is returned as is;
// an empty response without an error yields the zero User and a nil error;
// otherwise the result of json.Unmarshal is returned.
func (c *Client) GetUser(id string, quit <-chan struct{}) (user User, err error) {
	// Execute request
	data, err := c.get("users/"+id, nil, quit)
	glog.Infof("USER DATA %s", data)
	if err != nil || len(data) == 0 {
		return user, err
	}

	// Parse response. The decoded user is returned directly; re-marshalling
	// it into data would be wasted work because data is not used afterwards.
	err = json.Unmarshal(data, &user)
	return user, err
}
Here is what I see in the console (most of the user details removed)
I1228 13:25:05.291010 21313 client.go:177] GET https://api.github.com/users/xxxxxxxx
I1228 13:25:06.010085 21313 client.go:36] USER DATA {"login":"xxxxxxxx","id":00000000,"avatar_url":"https://avatars.githubusercontent.com/u/0000000?v=3",...}
I1228 13:25:06.010357 21313 github.go:90] No error in fetchUser
==========EDIT=============
Here is the relevant portion of the api package.
package api
// Client holds the user token and the HTTP client created from it.
type Client struct {
authToken auth.UserToken
http *http.Client
}
// NewClient returns a Client that stores authToken and uses the HTTP client
// produced by auth.NewClient for that token.
func NewClient(authToken auth.UserToken) *Client {
return &Client{
authToken: authToken,
http: auth.NewClient(authToken),
}
}
// -------------------------------------------------------------------------------------
// User is the subset of the user JSON document that is decoded; the struct
// tags give the JSON key each field is read from.
type User struct {
Id int `json:"id,omitempty"`
Username string `json:"login,omitempty"`
Email string `json:"email,omitempty"`
FullName string `json:"name,omitempty"`
ProfilePicture string `json:"avatar_url,omitempty"`
Bio string `json:"bio,omitempty"`
Website string `json:"blog,omitempty"`
Company string `json:"company,omitempty"`
}
And the channel package
package channel
// Channel pairs an implementation's descriptor with the implementation itself.
type Channel struct {
implementation.Descriptor
imp implementation.Implementation
}
// New returns a channel implementation with a given name and auth token.
// ErrInvalidChannel is returned when the name is not registered or when no
// implementation can be created for it.
func New(name string, token auth.UserToken) (*Channel, error) {
	desc, found := implementation.Lookup(name)
	if !found {
		return nil, ErrInvalidChannel
	}
	imp := implementation.New(name, token)
	if imp == nil {
		return nil, ErrInvalidChannel
	}
	return &Channel{Descriptor: desc, imp: imp}, nil
}
and the implementation package...
package implementation
import "golang.org/x/net/context"
// -------------------------------------------------------------------------------------
// Implementation is the interface implemented by subpackages.
type Implementation interface {
// Synchronize performs a synchronization using the given state. A context parameter
// is provided to allow cancellation as well as implementation-specific behaviors.
// Results are delivered on the returned output channel and errors on the
// returned error channel.
//
// If a fatal error occurs (see package error definitions), the state can be discarded
// to prevent the persistence of an invalid state.
Synchronize(state MutableState, ctx context.Context) (<-chan *Output, <-chan error)
// FetchDetails gets details for a given timeline item. Any changes to the TimelineItem
// (including the Meta value) will be persisted.
FetchDetails(item *TimelineItem, ctx context.Context) (interface{}, error)
}
======Edit #2=======
This is the original Synchronize method. I removed some details in my testing to try and simplify the problem. By removing a go func call, I believe I introduced a new problem which could be confusing things.
Here is the original Synchronize method. There are some things with Wait Groups and a function array containing a single function because this method will eventually be synchronizing multiple functions.
// Synchronize runs every fetch function in its own goroutine and returns the
// output and error channels immediately. Both channels are closed by a
// background goroutine once all fetch goroutines have finished.
func (g *github) Synchronize(state implementation.MutableState, ctx context.Context) (<-chan *implementation.Output, <-chan error) {
wg := sync.WaitGroup{}
output := make(chan *implementation.Output)
errors := make(chan error, 1) // buffer allows preflight errors
// Close output channels once we're done
// The deferred function runs when Synchronize returns, after every wg.Add
// below has been made, and starts the goroutine that waits and then closes.
defer func() {
go func() {
wg.Wait()
close(errors)
close(output)
}()
}()
// Perform fetch functions in separate routines
funcs := []synchronizeFunc{
g.fetchUser,
}
for _, f := range funcs {
wg.Add(1)
// f is passed as an argument so each goroutine uses its own function value.
go func(f synchronizeFunc) {
defer wg.Done()
if err := f(state, output, ctx); err != nil {
errors <- err
}
}(f)
}
glog.Info("after go sync...")
return output, errors
}
I think the two problems are in
output <- &implementation.Output{Data: user}
the channel does not have a buffer. A send on it blocks until some other goroutine receives from it. But in your code the send happens in the same goroutine that is supposed to read later, so it will block.
and second:
// Close output channels once we're done
defer func() {
go func() {
// wg.Wait()
close(errors)
close(output)
}()
}()
you launch a goroutine when the function exits. The goroutine is scheduled and the function returns, but the goroutine never actually runs.
I would suggest to unify all that logic in one:
// Synchronize starts the user fetch in a background goroutine and returns the
// output and error channels immediately.
//
// When fetchUser returns, the goroutine forwards its error (if any) and then
// closes errors followed by output. Because output is unbuffered and errors is
// only closed after fetchUser returns, a caller should receive from output
// while (or before) waiting on the error channel.
func (g *github) Synchronize(state implementation.MutableState, ctx context.Context) (<-chan *implementation.Output, <-chan error) {
output := make(chan *implementation.Output)
errors := make(chan error, 1) // buffer allows preflight errors
go func() {
// Deferred calls run last-in-first-out: errors is closed before output.
defer close(output)
defer close(errors)
err := g.fetchUser(state, output, ctx)
if err != nil {
errors <- err
}
}()
return output, errors
}

Query WMI from Go

I would like to run WMI queries from Go. There are ways to call DLL functions from Go. My understanding is that there must be some DLL somewhere which, with the correct call, will return some data I can parse and use. I'd prefer to avoid calling into C or C++, especially since I would guess those are wrappers over the Windows API itself.
I've examined the output of dumpbin.exe /exports c:\windows\system32\wmi.dll, and the following entry looks promising:
WmiQueryAllDataA (forwarded to wmiclnt.WmiQueryAllDataA)
However I'm not sure what to do from here. What arguments does this function take? What does it return? Searching for WmiQueryAllDataA is not helpful. And that name only appears in a comment of c:\program files (x86)\windows kits\8.1\include\shared\wmistr.h, but with no function signature.
Are there better methods? Is there another DLL? Am I missing something? Should I just use a C wrapper?
Running a WMI query in Linqpad with .NET Reflector shows the use of WmiNetUtilsHelper:ExecQueryWmi (and a _f version), but neither have a viewable implementation.
Update: use the github.com/StackExchange/wmi package which uses the solution in the accepted answer.
Welcome to the wonderful world of COM, Object Oriented Programming in C from when C++ was "a young upstart".
On github mattn has thrown together a little wrapper in Go, which I used to throw together a quick example program. "This repository was created for experimentation and should be considered unstable." instills all sorts of confidence.
I'm leaving out a lot of error checking. Trust me when I say, you'll want to add it back.
package main
import (
"github.com/mattn/go-ole"
"github.com/mattn/go-ole/oleutil"
)
// main initializes COM, connects to WMI through WbemScripting.SWbemLocator,
// runs "SELECT * FROM Win32_Process" and prints the Name of every result.
// Error returns are deliberately ignored to keep the example short.
func main() {
	// init COM, oh yeah
	ole.CoInitialize(0)
	defer ole.CoUninitialize()

	unknown, _ := oleutil.CreateObject("WbemScripting.SWbemLocator")
	defer unknown.Release()

	wmi, _ := unknown.QueryInterface(ole.IID_IDispatch)
	defer wmi.Release()

	// service is a SWbemServices
	serviceRaw, _ := oleutil.CallMethod(wmi, "ConnectServer")
	service := serviceRaw.ToIDispatch()
	defer service.Release()

	// result is a SWBemObjectSet
	resultRaw, _ := oleutil.CallMethod(service, "ExecQuery", "SELECT * FROM Win32_Process")
	result := resultRaw.ToIDispatch()
	defer result.Release()

	countVar, _ := oleutil.GetProperty(result, "Count")
	count := int(countVar.Val)

	for i := 0; i < count; i++ {
		// Each iteration runs in its own function so the deferred Release
		// fires at the end of the iteration rather than piling up until main
		// returns.
		func() {
			// item is a SWbemObject, but really a Win32_Process
			itemRaw, _ := oleutil.CallMethod(result, "ItemIndex", i)
			item := itemRaw.ToIDispatch()
			defer item.Release()

			asString, _ := oleutil.GetProperty(item, "Name")
			println(asString.ToString())
		}()
	}
}
The real meat is the call to ExecQuery, I happen to grab Win32_Process from the available classes because it's easy to understand and print.
On my machine, this prints:
System Idle Process
System
smss.exe
csrss.exe
wininit.exe
services.exe
lsass.exe
svchost.exe
svchost.exe
atiesrxx.exe
svchost.exe
svchost.exe
svchost.exe
svchost.exe
svchost.exe
spoolsv.exe
svchost.exe
AppleOSSMgr.exe
AppleTimeSrv.exe
... and so on
go.exe
main.exe
I'm not running it elevated or with UAC disabled, but some WMI providers are gonna require a privileged user.
I'm also not 100% sure that this won't leak a little; you'll want to dig into that. COM objects are reference counted, so defer should be a pretty good fit there (provided the method isn't crazy long running), but go-ole may have some magic inside that I didn't notice.
I'm commenting over a year later, but there is a solution here on github (and posted below for posterity).
// +build windows
/*
Package wmi provides a WQL interface for WMI on Windows.
Example code to print names of running processes:
type Win32_Process struct {
Name string
}
func main() {
var dst []Win32_Process
q := wmi.CreateQuery(&dst, "")
err := wmi.Query(q, &dst)
if err != nil {
log.Fatal(err)
}
for i, v := range dst {
println(i, v.Name)
}
}
*/
package wmi
import (
"bytes"
"errors"
"fmt"
"log"
"os"
"reflect"
"runtime"
"strconv"
"strings"
"sync"
"time"
"github.com/mattn/go-ole"
"github.com/mattn/go-ole/oleutil"
)
// l is the package logger; it writes to standard output with the standard
// log flags.
var l = log.New(os.Stdout, "", log.LstdFlags)
var (
// ErrInvalidEntityType is returned by Query when dst does not have a
// supported type.
ErrInvalidEntityType = errors.New("wmi: invalid entity type")
// lock is held for the duration of each Query call.
lock sync.Mutex
)
// QueryNamespace invokes Query with the given namespace on the local machine.
// It passes nil and namespace, in that order, as the connectServerArgs of Query.
func QueryNamespace(query string, dst interface{}, namespace string) error {
return Query(query, dst, nil, namespace)
}
// Query runs the WQL query and appends the values to dst.
//
// dst must have type *[]S or *[]*S, for some struct type S. Fields selected in
// the query must have the same name in dst. Supported types are all signed and
// unsigned integers, time.Time, string, bool, or a pointer to one of those.
// Array types are not supported.
//
// By default, the local machine and default namespace are used. These can be
// changed using connectServerArgs. See
// http://msdn.microsoft.com/en-us/library/aa393720.aspx for details.
//
// Query returns ErrInvalidEntityType when dst is not a non-nil pointer of a
// supported shape, the first COM or WMI error encountered, or, if the only
// problems were field mismatches, the last *ErrFieldMismatch seen.
func Query(query string, dst interface{}, connectServerArgs ...interface{}) error {
dv := reflect.ValueOf(dst)
if dv.Kind() != reflect.Ptr || dv.IsNil() {
return ErrInvalidEntityType
}
dv = dv.Elem()
mat, elemType := checkMultiArg(dv)
if mat == multiArgTypeInvalid {
return ErrInvalidEntityType
}
// Hold the package lock and pin the goroutine to its OS thread for the
// whole call; both are released when Query returns.
lock.Lock()
defer lock.Unlock()
runtime.LockOSThread()
defer runtime.UnlockOSThread()
err := ole.CoInitializeEx(0, ole.COINIT_MULTITHREADED)
if err != nil {
oleerr := err.(*ole.OleError)
// S_FALSE = 0x00000001 // CoInitializeEx was already called on this thread
if oleerr.Code() != ole.S_OK && oleerr.Code() != 0x00000001 {
return err
}
} else {
// Only invoke CoUninitialize if the thread was not initialized before.
// This will allow other go packages based on go-ole play along
// with this library.
defer ole.CoUninitialize()
}
unknown, err := oleutil.CreateObject("WbemScripting.SWbemLocator")
if err != nil {
return err
}
defer unknown.Release()
wmi, err := unknown.QueryInterface(ole.IID_IDispatch)
if err != nil {
return err
}
defer wmi.Release()
// service is a SWbemServices
serviceRaw, err := oleutil.CallMethod(wmi, "ConnectServer", connectServerArgs...)
if err != nil {
return err
}
service := serviceRaw.ToIDispatch()
defer serviceRaw.Clear()
// result is a SWBemObjectSet
resultRaw, err := oleutil.CallMethod(service, "ExecQuery", query)
if err != nil {
return err
}
result := resultRaw.ToIDispatch()
defer resultRaw.Clear()
count, err := oleInt64(result, "Count")
if err != nil {
return err
}
// Initialize a slice with Count capacity
dv.Set(reflect.MakeSlice(dv.Type(), 0, int(count)))
var errFieldMismatch error
for i := int64(0); i < count; i++ {
// Each item is handled in its own function so its deferred Clear runs at
// the end of the iteration.
err := func() error {
// item is a SWbemObject, but really a Win32_Process
itemRaw, err := oleutil.CallMethod(result, "ItemIndex", i)
if err != nil {
return err
}
item := itemRaw.ToIDispatch()
defer itemRaw.Clear()
ev := reflect.New(elemType)
if err = loadEntity(ev.Interface(), item); err != nil {
if _, ok := err.(*ErrFieldMismatch); ok {
// We continue loading entities even in the face of field mismatch errors.
// If we encounter any other error, that other error is returned. Otherwise,
// an ErrFieldMismatch is returned.
errFieldMismatch = err
} else {
return err
}
}
// ev is a pointer to the new element; dereference it unless dst holds
// pointers to structs.
if mat != multiArgTypeStructPtr {
ev = ev.Elem()
}
dv.Set(reflect.Append(dv, ev))
return nil
}()
if err != nil {
return err
}
}
return errFieldMismatch
}
// ErrFieldMismatch is the error reported when a value cannot be stored in a
// field of the destination struct: the field has a different type than the
// value it is loaded from, or the field is missing or unexported.
type ErrFieldMismatch struct {
	// StructType is documented as the type of the struct pointed to by the
	// destination argument.
	StructType reflect.Type
	// FieldName is the name of the field that could not be loaded.
	FieldName string
	// Reason is a short, human-readable explanation of the failure.
	Reason string
}

// Error implements the error interface. The field name and the type are
// rendered quoted, followed by the reason.
func (e *ErrFieldMismatch) Error() string {
	const format = "wmi: cannot load field %q into a %q: %s"
	return fmt.Sprintf(format, e.FieldName, e.StructType, e.Reason)
}
var timeType = reflect.TypeOf(time.Time{})

// loadEntity loads a SWbemObject into a struct pointer.
//
// dst must be a non-nil pointer to a struct. For every field of that struct
// the property with the same name is read from src and converted to the
// field's type. Pointer fields are allocated before the property is read.
//
// A property that cannot be read is recorded as an *ErrFieldMismatch and
// loading continues with the next field; the last such mismatch is returned
// once all fields have been visited. Any other failure (unsettable field,
// incompatible value type, unparsable number or timestamp) aborts loading and
// is returned immediately.
func loadEntity(dst interface{}, src *ole.IDispatch) (errFieldMismatch error) {
	v := reflect.ValueOf(dst).Elem()
	for i := 0; i < v.NumField(); i++ {
		f := v.Field(i)
		n := v.Type().Field(i).Name
		// Check settability before touching the field: calling Set on an
		// unexported pointer field would panic instead of reporting an error.
		if !f.CanSet() {
			return &ErrFieldMismatch{
				StructType: f.Type(),
				FieldName:  n,
				Reason:     "CanSet() is false",
			}
		}
		isPtr := f.Kind() == reflect.Ptr
		if isPtr {
			f.Set(reflect.New(f.Type().Elem()))
			f = f.Elem()
		}
		prop, err := oleutil.GetProperty(src, n)
		if err != nil {
			errFieldMismatch = &ErrFieldMismatch{
				StructType: f.Type(),
				FieldName:  n,
				Reason:     "no such struct field",
			}
			continue
		}
		// Run the conversion in its own function so the VARIANT is released
		// as soon as this field is done, not when the whole entity is loaded.
		err = func() error {
			defer prop.Clear()
			return loadEntityField(f, isPtr, n, prop.Value())
		}()
		if err != nil {
			return err
		}
	}
	return errFieldMismatch
}

// loadEntityField stores raw, the Go value of a property, into the settable
// field f named n. isPtr reports whether the struct field was declared as a
// pointer, in which case a nil property value is accepted and the freshly
// allocated zero value is kept.
func loadEntityField(f reflect.Value, isPtr bool, n string, raw interface{}) error {
	switch val := raw.(type) {
	case int, int8, int16, int32, int64:
		iv := reflect.ValueOf(val).Int()
		switch f.Kind() {
		case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
			f.SetInt(iv)
		case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
			f.SetUint(uint64(iv))
		default:
			return &ErrFieldMismatch{
				StructType: f.Type(),
				FieldName:  n,
				Reason:     "not an integer class",
			}
		}
	case uint, uint8, uint16, uint32, uint64:
		uv := reflect.ValueOf(val).Uint()
		switch f.Kind() {
		case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
			f.SetInt(int64(uv))
		case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
			f.SetUint(uv)
		default:
			return &ErrFieldMismatch{
				StructType: f.Type(),
				FieldName:  n,
				Reason:     "not an integer class",
			}
		}
	case string:
		switch f.Kind() {
		case reflect.String:
			f.SetString(val)
		case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
			iv, err := strconv.ParseInt(val, 10, 64)
			if err != nil {
				return err
			}
			f.SetInt(iv)
		case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
			iv, err := strconv.ParseInt(val, 10, 64)
			if err == nil {
				f.SetUint(uint64(iv))
			} else {
				// Values above math.MaxInt64 do not fit a signed parse; retry
				// as unsigned before giving up with the original error.
				uv, uerr := strconv.ParseUint(val, 10, 64)
				if uerr != nil {
					return err
				}
				f.SetUint(uv)
			}
		case reflect.Struct:
			switch f.Type() {
			case timeType:
				// A 25-character timestamp ends in a sign followed by a
				// three-digit offset in minutes; rewrite that offset as hhmm
				// so it matches the layout used below.
				if len(val) == 25 {
					mins, err := strconv.Atoi(val[22:])
					if err != nil {
						return err
					}
					val = val[:22] + fmt.Sprintf("%02d%02d", mins/60, mins%60)
				}
				t, err := time.Parse("20060102150405.000000-0700", val)
				if err != nil {
					return err
				}
				f.Set(reflect.ValueOf(t))
			}
		}
	case bool:
		switch f.Kind() {
		case reflect.Bool:
			f.SetBool(val)
		default:
			return &ErrFieldMismatch{
				StructType: f.Type(),
				FieldName:  n,
				Reason:     "not a bool",
			}
		}
	default:
		// A nil value is acceptable for pointer fields.
		if isPtr && raw == nil {
			return nil
		}
		return &ErrFieldMismatch{
			StructType: f.Type(),
			FieldName:  n,
			Reason:     fmt.Sprintf("unsupported type (%T)", val),
		}
	}
	return nil
}
// multiArgType classifies the element type of a destination slice.
type multiArgType int

const (
	// multiArgTypeInvalid marks anything that is not a slice of structs or
	// of struct pointers.
	multiArgTypeInvalid multiArgType = iota
	// multiArgTypeStruct marks a slice of the form []S.
	multiArgTypeStruct
	// multiArgTypeStructPtr marks a slice of the form []*S.
	multiArgTypeStructPtr
)

// checkMultiArg verifies that v is a slice of type []S or []*S for some
// struct type S.
//
// It reports which of the two shapes was found together with the
// reflect.Type of S. For every other input it returns multiArgTypeInvalid
// and a nil type.
func checkMultiArg(v reflect.Value) (m multiArgType, elemType reflect.Type) {
	if v.Kind() != reflect.Slice {
		return multiArgTypeInvalid, nil
	}
	elem := v.Type().Elem()
	if elem.Kind() == reflect.Struct {
		return multiArgTypeStruct, elem
	}
	if elem.Kind() == reflect.Ptr && elem.Elem().Kind() == reflect.Struct {
		return multiArgTypeStructPtr, elem.Elem()
	}
	return multiArgTypeInvalid, nil
}
// oleInt64 fetches the property named prop from item and returns the raw Val
// of the resulting variant converted to int64. The variant is cleared before
// the function returns. If the property cannot be read, 0 and the error from
// the lookup are returned.
func oleInt64(item *ole.IDispatch, prop string) (int64, error) {
	variant, err := oleutil.GetProperty(item, prop)
	if err != nil {
		return 0, err
	}
	defer variant.Clear()
	return int64(variant.Val), nil
}
// CreateQuery returns a WQL query string that queries all columns of src. where
// is an optional string that is appended to the query, to be used with WHERE
// clauses. In such a case, the "WHERE" string should appear at the beginning.
//
// src may be a struct S, a slice []S or []*S, or a pointer to any of those.
// The field names of S become the selected columns and the name of S becomes
// the queried class. Only the type of src is inspected, so nil pointers and
// nil slices are fine. If src is nil or no struct type can be derived from
// it, the empty string is returned.
func CreateQuery(src interface{}, where string) string {
	t := reflect.TypeOf(src)
	if t == nil {
		return ""
	}
	if t.Kind() == reflect.Ptr {
		t = t.Elem()
	}
	if t.Kind() == reflect.Slice {
		t = t.Elem()
		// Accept []*S as well as []S.
		if t.Kind() == reflect.Ptr {
			t = t.Elem()
		}
	}
	if t.Kind() != reflect.Struct {
		return ""
	}
	fields := make([]string, 0, t.NumField())
	for i := 0; i < t.NumField(); i++ {
		fields = append(fields, t.Field(i).Name)
	}
	var b bytes.Buffer
	b.WriteString("SELECT ")
	b.WriteString(strings.Join(fields, ", "))
	b.WriteString(" FROM ")
	b.WriteString(t.Name())
	b.WriteString(" " + where)
	return b.String()
}
To access the winmgmts object or a namespace (which is the same), you can use the code below. Basically, you need to specify the namespace as a parameter, which is not documented properly in go-ole.
In the code below, you can also see how to access a class within this namespace and execute a method.
package main
import (
"log"
"github.com/go-ole/go-ole"
"github.com/go-ole/go-ole/oleutil"
)
// main connects to the root/PanasonicPC WMI namespace through the
// WbemScripting.SWbemLocator COM object, fetches the SetBIOS4Conf class and
// calls its AccessAuthorization method, logging the returned code. Any
// failure along the way is reported with log.Panic.
func main() {
	if err := ole.CoInitializeEx(0, ole.COINIT_MULTITHREADED); err != nil {
		// S_FALSE (0x00000001) means CoInitializeEx was already called on
		// this thread; that is not a failure.
		oleErr, ok := err.(*ole.OleError)
		if !ok || (oleErr.Code() != ole.S_OK && oleErr.Code() != 0x00000001) {
			log.Panic(err)
		}
	}
	defer ole.CoUninitialize()

	unknown, err := oleutil.CreateObject("WbemScripting.SWbemLocator")
	if err != nil {
		log.Panic(err)
	}
	defer unknown.Release()

	wmi, err := unknown.QueryInterface(ole.IID_IDispatch)
	if err != nil {
		log.Panic(err)
	}
	defer wmi.Release()

	// Connect to namespace
	// root/PanasonicPC = winmgmts:\\.\root\PanasonicPC
	serviceRaw, err := oleutil.CallMethod(wmi, "ConnectServer", nil, "root/PanasonicPC")
	if err != nil {
		log.Panic(err)
	}
	service := serviceRaw.ToIDispatch()
	defer serviceRaw.Clear()

	// Get class
	setBiosRaw, err := oleutil.CallMethod(service, "Get", "SetBIOS4Conf")
	if err != nil {
		log.Panic(err)
	}
	setBios := setBiosRaw.ToIDispatch()
	defer setBiosRaw.Clear()

	// Run method
	resultRaw, err := oleutil.CallMethod(setBios, "AccessAuthorization", "letmein")
	if err != nil {
		// Without this check a failed call would be dereferenced below.
		log.Panic(err)
	}
	defer resultRaw.Clear()

	resultVal, ok := resultRaw.Value().(int32)
	if !ok {
		log.Panicf("unexpected return type %T", resultRaw.Value())
	}
	log.Println("Return Code:", resultVal)
}
import(
"os/exec"
)
​func​ (​lcu​ ​*​LCU​) ​GrabToken​() {
​        ​cmd​ ​:=​ ​exec​.​Command​(​"powershell"​, ​"$cmdline = Get-WmiObject -Class Win32_Process"​)
​        ​
​        ​out​, ​err​ ​:=​ ​cmd​.​CombinedOutput​()
​        ​if​ ​err​ ​!=​ ​nil​ {
​                ​fmt​.​Println​(​err​)
​        }
​        ​outstr​ ​:=​ ​string(out)
​}