Custom error handling in Go

I've been reading this blog post but I'm still not convinced I know exactly what to do to get custom errors that I can return from my functions and handle outside of them.
This is what I'm currently doing:
func doSomething() int {
x := 0
// Do something with x.
...
if somethingBadHappened {
return -1
}
if somethingElseBadHappened {
return -2
}
return x
}
This is what I'd like to be doing:
func doSomething() int, ? {
...
if somethingBadHappened {
return ?, err
}
if somethingElseBadHappened {
return ?, err2
}
return x, nil
}
But I'm not exactly sure how, and what to replace those question marks with.

You don't really need to return an int if you don't want to. You can do something like:
func doSomething() error {
...
if somethingBadHappened {
return errors.New("something bad happened")
}
if somethingElseBadHappened {
return errors.New("something else bad happened")
}
return nil
}
Or, if you want to return an int as well:
func doSomething() (int, error) {
...
if somethingBadHappened {
return -1, errors.New("something bad happened")
}
if somethingElseBadHappened {
return -2, errors.New("something else bad happened")
}
return x, nil
}
Be sure to import "errors" at the top.
If you want to test whether you got an error, you can do
x, err := doSomething()
if err != nil {
log.Println(err)
}
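If the caller needs to tell the two failures apart rather than just log them, one common option is to declare the errors once as package-level values and compare against them with errors.Is (Go 1.13 or later). This is only a sketch: the names ErrSomethingBad and ErrSomethingElseBad, the two boolean parameters, and the value 42 are made up for illustration.

```go
package main

import (
	"errors"
	"fmt"
)

// Hypothetical sentinel errors; callers compare against these values.
var (
	ErrSomethingBad     = errors.New("something bad happened")
	ErrSomethingElseBad = errors.New("something else bad happened")
)

// doSomething stands in for the function from the question. The two
// boolean parameters are placeholders for the real failure conditions.
func doSomething(bad, elseBad bool) (int, error) {
	x := 42
	if bad {
		return 0, ErrSomethingBad
	}
	if elseBad {
		// %w wraps the sentinel, so errors.Is still matches it.
		return 0, fmt.Errorf("doSomething: %w", ErrSomethingElseBad)
	}
	return x, nil
}

func main() {
	_, err := doSomething(false, true)
	switch {
	case errors.Is(err, ErrSomethingBad):
		fmt.Println("handle the first failure")
	case errors.Is(err, ErrSomethingElseBad):
		fmt.Println("handle the second failure")
	case err != nil:
		fmt.Println("unexpected error:", err)
	}
}
```

If an error has to carry extra data (a code, a file name), a struct type with an Error() string method plus errors.As plays the same role.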

I would turn
func doSomething() int, ? {
...
if somethingBadHappened {
return ?, err
}
if somethingElseBadHappened {
return ?, err2
}
return x, nil
}
into
func doSomething() (r int, err error) {
...
if somethingBadHappened {
err = err1 // Whatever satisfies the `error` interface
return
}
if somethingElseBadHappened {
err = err2 // ditto
return
}
return x, nil
}
In other words, at the call site it is idiomatic (*) to ignore any other returned value when err != nil, and never to use or rely on it, so it doesn't matter whether r above was assigned some intermediate value.
(*) To a first approximation, i.e. unless stated otherwise. For example, io.Reader explicitly documents that Read can return both valid data and err == io.EOF from the same call:
When Read encounters an error or end-of-file condition after successfully reading n > 0 bytes, it returns the number of bytes read. It may return the (non-nil) error from the same call or return the error (and n == 0) from a subsequent call. An instance of this general case is that a Reader returning a non-zero number of bytes at the end of the input stream may return either err == EOF or err == nil. The next Read should return 0, EOF regardless.
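A rough sketch of what that exception means for a caller: use the n bytes first, then look at err. The helper names countBytes and newFinalDataReader are made up here; iotest.DataErrReader from the standard library is used only to force the "data and EOF in the same call" case.

```go
package main

import (
	"fmt"
	"io"
	"strings"
	"testing/iotest"
)

// countBytes drains r and returns how many bytes it delivered.
// It accounts for n before looking at err, as the io.Reader docs ask.
func countBytes(r io.Reader) (int, error) {
	buf := make([]byte, 4)
	total := 0
	for {
		n, err := r.Read(buf)
		total += n // buf[:n] is valid even when err != nil
		if err == io.EOF {
			return total, nil
		}
		if err != nil {
			return total, err
		}
	}
}

// newFinalDataReader returns a reader whose last Read carries both
// data and io.EOF (hypothetical helper for this example).
func newFinalDataReader(s string) io.Reader {
	return iotest.DataErrReader(strings.NewReader(s))
}

func main() {
	n, err := countBytes(newFinalDataReader("hello, world"))
	fmt.Println(n, err) // prints "12 <nil>"
}
```

Had the loop checked err before adding n, the last four bytes would have been dropped with this reader.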

Related

Testing Emitter func with channel return value in go

I am having a hard time getting my test to work for my emitter function, which passes results through a channel for a data pipeline. This function will be triggered periodically and will pull records from the database. I compiled a stripped-down version for this question; the real code would be more complex but would follow the same pattern. For testing I mocked the access to the database because I want to test the behavior of the Emitter function.
I guess code says more than words:
This is the method I want to test:
//EmittRecord pull record from database
func EmittRecord(svc Service, count int) <-chan *Result {
out := make(chan *Result)
go func() {
defer close(out)
for i := 0; i < count; i++ {
r, err := svc.Next()
if err != nil {
out <- &Result{Error: err}
continue
}
out <- &Result{Payload: &Payload{
Field1: r.Field1,
Field2: r.Field2,
}, Error: nil}
}
}()
return out
}
I have a couple of types with an interface:
//Record is a Record from db
type Record struct {
Field1 string
Field2 string
}
//Payload is a record for the data pipeline
type Payload struct {
Field1 string
Field2 string
}
//Result is a type for the data pipeline
type Result struct {
Payload *Payload
Error error
}
//Service is an abstraction to access the database
type Service interface {
Next() (*Record, error)
}
This is my service Mock for testing:
//MockService is a struct to support testing for mocking the database
type MockService struct {
NextMock func() (*Record, error)
}
//Next is an Implementation of the Service interface for the mock
func (m *MockService) Next() (*Record, error) {
if m.NextMock != nil {
return m.NextMock()
}
panic("Please set NextMock!")
}
And finally, this is my test method, which does not work. It does not hit the done case and does not hit the 1*time.Second timeout case either; the test just times out. I guess I am missing something here.
func TestEmitter(t *testing.T) {
tt := []struct {
name string
svc runner.Service
expectedResult runner.Result
}{
{name: "Database returns error",
svc: &runner.MockService{
NextMock: func() (*runner.Record, error) {
return nil, fmt.Errorf("YIKES")
},
},
expectedResult: runner.Result{Payload: nil, Error: fmt.Errorf("RRRR")},
},
{name: "Database returns record",
svc: &runner.MockService{
NextMock: func() (*runner.Record, error) {
return &runner.Record{
Field1: "hello",
Field2: "world",
}, nil
},
},
},
}
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
done := make(chan bool)
defer close(done)
var output <-chan *runner.Result
go func() {
output = runner.EmittRecord(tc.svc, 1)
done <- true
}()
found := <-output
<-done
select {
case <-done:
case <-time.After(1 * time.Second):
panic("timeout")
}
if found.Error.Error() != tc.expectedResult.Error.Error() {
t.Errorf("FAIL: %s, expected: %s; but got %s", tc.name, tc.expectedResult.Error.Error(), found.Error.Error())
} else if reflect.DeepEqual(found.Payload, tc.expectedResult.Payload) {
t.Errorf("FAIL: %s, expected: %+v; got %+v", tc.name, tc.expectedResult.Payload, found.Payload)
}
})
}
}
It would be great if someone could advise me on what I am missing here, and maybe give some input on how to verify the count passed to EmittRecord; right now it is only set to 1.
Thanks in advance
//Edited: the expectedResult as per Comment by #Lansana
Are you sure you have your expected results in the tests set to the proper value?
In the first slice in the test, you expect a fmt.Errorf("RRRR"), yet the mock returns a fmt.Errorf("YIKES").
And then later in the actual test conditionals, you do this:
if found.Error.Error() != "Hello" {
t.Errorf("FAIL: %s, expected: %s; but got %s", tc.name, tc.expectedResult.Error.Error(), found.Error.Error())
}
You are checking "Hello". Shouldn't you be checking if it's an error with the message "YIKES"?
I think your logic is good, but your test is just not properly written. Check my Go Playground example here and run the code. You will see there is no output or panics when you run it. This is because the code passes my test conditions in main.
You are adding more complexity to your test by more channels, and if those extra channels are invalid then you may have some false positives that make you think your business logic is bad. In this case, it actually seems to be working as it should.
Here is the highlight of the code from my playground example (the part that tests your logic):
func main() {
svc1 := &MockService{
NextMock: func() (*Record, error) {
return nil, errors.New("foo")
},
}
for item := range EmittRecord(svc1, 5) {
if item.Payload != nil {
panic("item.Payload should be nil")
}
if item.Error == nil {
panic("item.Error should be an error")
}
}
svc2 := &MockService{
NextMock: func() (*Record, error) {
return &Record{Field1: "Hello ", Field2: "World"}, nil
},
}
for item := range EmittRecord(svc2, 5) {
if item.Payload == nil {
panic("item.Payload should have a value")
}
if item.Payload.Field1 + item.Payload.Field2 != "Hello World" {
panic("item.Payload.Field1 and item.Payload.Field2 are invalid!")
}
if item.Error != nil {
panic("item.Error should be nil")
}
}
}
The output from the above code is nothing. No panics. Thus, it succeeded.
Try simplifying your test to a working state, and then add more complexity from there. :)
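For what it's worth, the hang in the original test most likely comes from `found := <-output` running while output is still nil (the goroutine that assigns it has not run yet), and a receive from a nil channel blocks forever. Since EmittRecord already returns the channel immediately, the extra goroutine and the done channel are not needed. Below is a sketch of reading the channel directly with a timeout, which also lets you check the count. The names funcService, collect, failing, working and waitLimit are invented for this example, and the types are trimmed copies of the ones in the question.

```go
package main

import (
	"errors"
	"fmt"
	"time"
)

type Record struct{ Field1, Field2 string }
type Payload struct{ Field1, Field2 string }
type Result struct {
	Payload *Payload
	Error   error
}
type Service interface{ Next() (*Record, error) }

// funcService adapts a plain function to Service (stand-in for MockService).
type funcService func() (*Record, error)

func (f funcService) Next() (*Record, error) { return f() }

// EmittRecord is the emitter from the question.
func EmittRecord(svc Service, count int) <-chan *Result {
	out := make(chan *Result)
	go func() {
		defer close(out)
		for i := 0; i < count; i++ {
			r, err := svc.Next()
			if err != nil {
				out <- &Result{Error: err}
				continue
			}
			out <- &Result{Payload: &Payload{Field1: r.Field1, Field2: r.Field2}}
		}
	}()
	return out
}

const waitLimit = time.Second

// collect drains out until it is closed, giving up if nothing
// arrives within timeout.
func collect(out <-chan *Result, timeout time.Duration) ([]*Result, error) {
	var got []*Result
	for {
		select {
		case r, ok := <-out:
			if !ok {
				return got, nil // channel closed: the emitter is done
			}
			got = append(got, r)
		case <-time.After(timeout):
			return got, errors.New("timed out waiting for emitter")
		}
	}
}

var (
	failing = funcService(func() (*Record, error) { return nil, errors.New("YIKES") })
	working = funcService(func() (*Record, error) {
		return &Record{Field1: "hello", Field2: "world"}, nil
	})
)

func main() {
	results, err := collect(EmittRecord(failing, 3), waitLimit)
	fmt.Println(len(results), err) // prints "3 <nil>"
}
```

In a real test you would call collect inside t.Run and compare len(results) with the count you passed in, then check each Result with t.Errorf.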

Mocking functions in Golang to test my http routes

I'm totally confused figuring out how I can mock a function, without using any additional packages like golang/mock. I'm just trying to learn how to do so but can't find many decent online resources.
Essentially, I followed this excellent article that explains how to use an interface to mock things.
So I've rewritten the function I wanted to test. The function just inserts some data into datastore. My tests for that are OK: I can mock the function directly.
The issue I'm having is mocking it 'within' an http route I'm trying to test. I'm using the Gin framework.
My router (simplified) looks like this:
func SetupRouter() *gin.Engine {
r := gin.Default()
r.Use(gin.Logger())
r.Use(gin.Recovery())
v1 := r.Group("v1")
v1.PATCH("operations/:id", controllers.UpdateOperation)
return r
}
Which calls the UpdateOperation function:
func UpdateOperation(c *gin.Context) {
id := c.Param("id")
r := m.Response{}
str := m.OperationInfoer{}
err := m.FindAndCompleteOperation(str, id, r.Report)
if err == nil {
c.JSON(200, gin.H{
"message": "Operation completed",
})
}
}
So, I need to mock the FindAndCompleteOperation() function.
The main (simplified) functions look like this:
func (oi OperationInfoer) FindAndCompleteOp(id string, report Report) error {
ctx := context.Background()
q := datastore.NewQuery("Operation").
Filter("Unique_Id =", id).
Limit(1)
var ops []Operation
if ts, err := db.Datastore.GetAll(ctx, q, &ops); err == nil {
{
if len(ops) > 0 {
ops[0].Id = ts[0].ID()
ops[0].Complete = true
// Do stuff
_, err := db.Datastore.Put(ctx, key, &o)
if err == nil {
log.Print("OPERATION COMPLETED")
}
}
}
}
err := errors.New("Not found")
return err
}
func FindAndCompleteOperation(ri OperationInfoer, id string, report Report) error {
return ri.FindAndCompleteOp(id, report)
}
type OperationInfoer struct{}
To test the route that updates the operation, I have something like so:
FIt("Return 200, updates operation", func() {
testRouter := SetupRouter()
param := make(url.Values)
param["access_token"] = []string{public_token}
report := m.Report{}
report.Success = true
report.Output = "my output"
jsonStr, _ := json.Marshal(report)
req, _ := http.NewRequest("PATCH", "/v1/operations/123?"+param.Encode(), bytes.NewBuffer(jsonStr))
resp := httptest.NewRecorder()
testRouter.ServeHTTP(resp, req)
Expect(resp.Code).To(Equal(200))
o := FakeResponse{}
json.NewDecoder(resp.Body).Decode(&o)
Expect(o.Message).To(Equal("Operation completed"))
})
Originally, I tried to cheat a bit and just tried something like this:
m.FindAndCompleteOperation = func(string, m.Report) error {
return nil
}
But that affects all the other tests etc.
I'm hoping someone can explain simply what the best way to mock the FindAndCompleteOperation function is, so I can test the routes without relying on datastore etc.
I have another relevant, more informative answer to a similar question here, but here's an answer for your specific scenario:
Update your SetupRouter() function to take a function that can either be the real FindAndCompleteOperation function or a stub function:
Playground
package main
import "github.com/gin-gonic/gin"
// m.Response.Report
type Report struct {
// ...
}
// m.OperationInfoer
type OperationInfoer struct {
// ...
}
type findAndComplete func(s OperationInfoer, id string, report Report) error
func FindAndCompleteOperation(OperationInfoer, string, Report) error {
// ...
return nil
}
func SetupRouter(f findAndComplete) *gin.Engine {
r := gin.Default()
r.Group("v1").PATCH("/:id", func(c *gin.Context) {
if f(OperationInfoer{}, c.Param("id"), Report{}) == nil {
c.JSON(200, gin.H{"message": "Operation completed"})
}
})
return r
}
func main() {
r := SetupRouter(FindAndCompleteOperation)
if err := r.Run(":8080"); err != nil {
panic(err)
}
}
Test/mocking example
package main
import (
"encoding/json"
"net/http/httptest"
"strings"
"testing"
)
func TestUpdateRoute(t *testing.T) {
// build findAndComplete stub
var callCount int
var lastInfoer OperationInfoer
var lastID string
var lastReport Report
stub := func(s OperationInfoer, id string, report Report) error {
callCount++
lastInfoer = s
lastID = id
lastReport = report
return nil // or `fmt.Errorf("Err msg")` if you want to test fault path
}
// invoke endpoint
w := httptest.NewRecorder()
r := httptest.NewRequest(
"PATCH",
"/v1/id_value",
strings.NewReader(""),
)
SetupRouter(stub).ServeHTTP(w, r)
// check that the stub was invoked correctly
if callCount != 1 {
t.Fatal("Wanted 1 call; got", callCount)
}
if lastInfoer != (OperationInfoer{}) {
t.Fatalf("Wanted %v; got %v", OperationInfoer{}, lastInfoer)
}
if lastID != "id_value" {
t.Fatalf("Wanted 'id_value'; got '%s'", lastID)
}
if lastReport != (Report{}) {
t.Fatalf("Wanted %v; got %v", Report{}, lastReport)
}
// check that the correct response was returned
if w.Code != 200 {
t.Fatal("Wanted HTTP 200; got HTTP", w.Code)
}
var body map[string]string
if err := json.Unmarshal(w.Body.Bytes(), &body); err != nil {
t.Fatal("Unexpected error:", err)
}
if body["message"] != "Operation completed" {
t.Fatalf("Wanted 'Operation completed'; got %s", body["message"])
}
}
You can't mock if you use globals that can't be mocked in a handler. Either your globals are mockable (i.e. declared as variables of interface type) or you need to use dependency injection.
func (oi OperationInfoer) FindAndCompleteOp(id string, report Report) error {...}
looks like a method of a struct, so you should be able to inject that struct into a handler, at the very least.
type OperationInfoer interface {
FindAndCompleteOp(id string, report Report) error
}
type ConcreteOperationInfoer struct { /* actual implementation */ }
func UpdateOperation(oi OperationInfoer) func(c *gin.Context) {
return func (c *gin.Context){
// the code
}
}
Then mocking becomes a breeze in your tests:
UpdateOperation(mockOperationInfoer)(ginContext)
You can use a struct instead of closures
type UpdateOperationHandler struct {
Oi OperationInfoer
}
func (h UpdateOperationHandler) UpdateOperation(c *gin.Context) {
h.Oi.FindAndCompleteOp(/* code */ )
}
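To make the injection idea concrete, here is a small runnable sketch. It is not the asker's code: it uses plain net/http instead of Gin so it runs with the standard library alone, it drops the Report parameter, it reads the id from a query parameter, and stubInfoer, call and errNotFound are names invented for the example.

```go
package main

import (
	"encoding/json"
	"errors"
	"fmt"
	"net/http"
	"net/http/httptest"
)

// OperationInfoer is the dependency the handler needs (simplified:
// no Report argument).
type OperationInfoer interface {
	FindAndCompleteOp(id string) error
}

// stubInfoer is a test double that records the id it was called with.
type stubInfoer struct {
	gotID string
	err   error
}

func (s *stubInfoer) FindAndCompleteOp(id string) error {
	s.gotID = id
	return s.err
}

var errNotFound = errors.New("not found")

// UpdateOperation returns a handler closed over the injected dependency.
func UpdateOperation(oi OperationInfoer) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		id := r.URL.Query().Get("id")
		if err := oi.FindAndCompleteOp(id); err != nil {
			http.Error(w, err.Error(), http.StatusNotFound)
			return
		}
		w.Header().Set("Content-Type", "application/json")
		json.NewEncoder(w).Encode(map[string]string{"message": "Operation completed"})
	}
}

// call runs the handler against an in-memory request and returns the
// HTTP status code.
func call(oi OperationInfoer, id string) int {
	req := httptest.NewRequest(http.MethodPatch, "/v1/operations?id="+id, nil)
	rec := httptest.NewRecorder()
	UpdateOperation(oi).ServeHTTP(rec, req)
	return rec.Code
}

func main() {
	ok := &stubInfoer{}
	fmt.Println(call(ok, "123"), ok.gotID) // prints "200 123"
	fmt.Println(call(&stubInfoer{err: errNotFound}, "456")) // prints "404"
}
```

Because each test builds its own stub, nothing global is replaced and other tests are not affected.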

Portable way to detect different kinds of network error

I would like to identify what kind of error occurred at the network level. The only way I found was checking the error messages with a regular expression, but I have now discovered that these messages can be in different languages (depending on the OS configuration), which makes them difficult to match with regular expressions. Is there a better way to do it?
package main
import (
"github.com/miekg/dns"
"net"
"regexp"
)
func main() {
var c dns.Client
m := new(dns.Msg)
m.SetQuestion("3com.br.", dns.TypeSOA)
_, _, err := c.Exchange(m, "ns1.3com.com.:53")
checkErr(err)
m.SetQuestion("example.com.", dns.TypeSOA)
_, _, err = c.Exchange(m, "idontexist.br.:53")
checkErr(err)
m.SetQuestion("acasadocartaocuritiba.blog.br.", dns.TypeSOA)
_, _, err = c.Exchange(m, "ns7.storedns22.in.:53")
checkErr(err)
}
func checkErr(err error) {
if err == nil {
println("Ok")
} else if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
println("Timeout")
} else if match, _ := regexp.MatchString(".*lookup.*", err.Error()); match {
println("Unknown host")
} else if match, _ := regexp.MatchString(".*connection refused.*", err.Error()); match {
println("Connection refused")
} else {
println("Other error")
}
}
Result:
$ go run neterrors.go
Timeout
Unknown host
Connection refused
I discovered the problem when testing the system on Windows with Portuguese as the default language.
[EDIT]
I found a way to do it using the OpError. Here is the checkErr function again with the new approach. If someone has a better solution I will be very glad to know it!
func checkErr(err error) {
if err == nil {
println("Ok")
} else if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
println("Timeout")
} else if opError, ok := err.(*net.OpError); ok {
if opError.Op == "dial" {
println("Unknown host")
} else if opError.Op == "read" {
println("Connection refused")
}
}
}
[EDIT2]
Updated after seong's answer.
func checkErr(err error) {
if err == nil {
println("Ok")
return
} else if netError, ok := err.(net.Error); ok && netError.Timeout() {
println("Timeout")
return
}
switch t := err.(type) {
case *net.OpError:
if t.Op == "dial" {
println("Unknown host")
} else if t.Op == "read" {
println("Connection refused")
}
case syscall.Errno:
if t == syscall.ECONNREFUSED {
println("Connection refused")
}
}
}
The net package works closely with your OS. For OS errors, the Go standard library uses the syscall package. Have a look here: http://golang.org/pkg/syscall/
The net package can also return syscall.Errno type errors.
For simpler code in your checkErr function, you could consider using a type switch (http://golang.org/doc/effective_go.html#type_switch).
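On Go 1.13 or later, errors.As and errors.Is can look through wrapped errors, which avoids both the regular expressions and the guesswork based on OpError.Op. The following is a sketch under some assumptions: classify and the sample* variables are made-up names, the sample errors are built by hand so the code runs without a network, and whether the errors coming out of the dns client actually wrap these types is something you would need to verify. On Windows the errno for a refused connection may be a WSA code rather than syscall.ECONNREFUSED, so that branch may need an extra case there.

```go
package main

import (
	"errors"
	"fmt"
	"net"
	"os"
	"syscall"
)

// classify maps an error to a coarse category without looking at its
// (possibly localized) message text.
func classify(err error) string {
	if err == nil {
		return "ok"
	}
	var netErr net.Error
	if errors.As(err, &netErr) && netErr.Timeout() {
		return "timeout"
	}
	var dnsErr *net.DNSError
	if errors.As(err, &dnsErr) {
		return "unknown host"
	}
	if errors.Is(err, syscall.ECONNREFUSED) {
		return "connection refused"
	}
	return "other"
}

// Hand-built sample errors (assumed shapes, not captured from a real run).
var (
	sampleTimeout error = &net.DNSError{Err: "i/o timeout", Name: "example.com", IsTimeout: true}
	sampleDNS     error = &net.OpError{
		Op:  "dial",
		Net: "udp",
		Err: &net.DNSError{Err: "no such host", Name: "idontexist.br", IsNotFound: true},
	}
	sampleRefused error = &net.OpError{
		Op:  "read",
		Net: "udp",
		Err: os.NewSyscallError("read", syscall.ECONNREFUSED),
	}
)

func main() {
	for _, err := range []error{nil, sampleTimeout, sampleDNS, sampleRefused, errors.New("boom")} {
		fmt.Println(classify(err))
	}
}
```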

io.WriteSeeker and io.ReadSeeker from []byte or file

I have a method called "DoSomething". DoSomething will take binary source data, perform an operation on it, and write out binary data. DoSomething needs to be generic enough to handle either a []byte slice or a file handle for both the source and destination. To accomplish this, I have attempted to declare the method like this:
func DoSomething(source *io.ReadSeeker, destination *io.WriteSeeker)
I have implemented the ReadSeeker and WriteSeeker for working with buffers, using my own custom, required methods (if there is a way to automatically accomplish this, I'd love to hear about it as well).
Unfortunately, I can't seem to figure out how to create either an io.ReadSeeker or io.WriteSeeker from a file handle. I'm fairly sure there must be some pre-cooked way of handling this without having to manually implement them. Is this possible?
A file already implements both of those. You can do something like this:
package main
import (
"fmt"
"io"
"os"
)
func main() {
f, err := os.Open("test.txt")
if err != nil {
fmt.Println(err)
return // f is unusable if Open failed
}
defer f.Close()
f2, err := os.Create("test2.txt")
if err != nil {
fmt.Println(err)
return // f2 is unusable if Create failed
}
defer f2.Close()
DoSomething(f, f2)
}
func DoSomething(source io.ReadSeeker, destination io.WriteSeeker) {
io.Copy(destination, source)
}
Also, you don't need to pass pointers to interfaces, which makes it easier to deal with them.
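For the []byte side of the source, the standard library already covers reading: bytes.NewReader returns a *bytes.Reader, which implements io.ReadSeeker. A small sketch follows; lastN and fromBytes are names made up for the example. As far as I know there is no in-memory io.WriteSeeker in the standard library, so the destination side still needs a custom type.

```go
package main

import (
	"bytes"
	"fmt"
	"io"
)

// lastN returns the final n bytes of source, using Seek to skip ahead.
func lastN(source io.ReadSeeker, n int64) ([]byte, error) {
	if _, err := source.Seek(-n, io.SeekEnd); err != nil {
		return nil, err
	}
	return io.ReadAll(source)
}

// fromBytes wraps a byte slice as an io.ReadSeeker without copying it.
func fromBytes(b []byte) io.ReadSeeker {
	return bytes.NewReader(b)
}

func main() {
	tail, err := lastN(fromBytes([]byte("hello, world")), 5)
	fmt.Println(string(tail), err) // prints "world <nil>"
}
```

The same lastN call works unchanged with an *os.File, since a file also implements io.ReadSeeker.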
For anyone else who needs to accomplish something like this, here's what I ended up with. It isn't complete, but it is close enough for what I needed:
package filebuffer
import (
"bytes"
"errors"
)
type FileBuffer struct {
Buffer bytes.Buffer
Index int64
}
func NewFileBuffer() FileBuffer {
return FileBuffer{}
}
func (fbuffer *FileBuffer) Bytes() []byte {
return fbuffer.Buffer.Bytes()
}
func (fbuffer *FileBuffer) Read(p []byte) (int, error) {
n, err := bytes.NewBuffer(fbuffer.Buffer.Bytes()[fbuffer.Index:]).Read(p)
if err == nil {
if fbuffer.Index+int64(len(p)) < int64(fbuffer.Buffer.Len()) {
fbuffer.Index += int64(len(p))
} else {
fbuffer.Index = int64(fbuffer.Buffer.Len())
}
}
return n, err
}
func (fbuffer *FileBuffer) Write(p []byte) (int, error) {
n, err := fbuffer.Buffer.Write(p)
if err == nil {
fbuffer.Index = int64(fbuffer.Buffer.Len())
}
return n, err
}
func (fbuffer *FileBuffer) Seek(offset int64, whence int) (int64, error) {
var err error
var Index int64 = 0
switch whence {
case 0:
if offset >= int64(fbuffer.Buffer.Len()) || offset < 0 {
err = errors.New("Invalid Offset.")
} else {
fbuffer.Index = offset
Index = offset
}
default:
err = errors.New("Unsupported Seek Method.")
}
return Index, err
}
You then use it like this:
destination := filebuffer.NewFileBuffer()
source, err := os.Open(pathString)
if err != nil {
return nil, err
}
defer source.Close()
if _, err := encrypter.Decrypt(source, &destination, password); err != nil {
return nil, err
}
For anyone who needs it, here's an implementation of io.ReadWriteSeeker, that supports all I/O operations:
import (
"errors"
"fmt"
"io"
)
// Implements io.ReadWriteSeeker for testing purposes.
type FileBuffer struct {
buffer []byte
offset int64
}
// Creates new buffer that implements io.ReadWriteSeeker for testing purposes.
func NewFileBuffer(initial []byte) FileBuffer {
if initial == nil {
initial = make([]byte, 0, 100)
}
return FileBuffer{
buffer: initial,
offset: 0,
}
}
func (fb *FileBuffer) Bytes() []byte {
return fb.buffer
}
func (fb *FileBuffer) Len() int {
return len(fb.buffer)
}
func (fb *FileBuffer) Read(b []byte) (int, error) {
available := len(fb.buffer) - int(fb.offset)
if available == 0 {
return 0, io.EOF
}
size := len(b)
if size > available {
size = available
}
copy(b, fb.buffer[fb.offset:fb.offset+int64(size)])
fb.offset += int64(size)
return size, nil
}
func (fb *FileBuffer) Write(b []byte) (int, error) {
copied := copy(fb.buffer[fb.offset:], b)
if copied < len(b) {
fb.buffer = append(fb.buffer, b[copied:]...)
}
fb.offset += int64(len(b))
return len(b), nil
}
func (fb *FileBuffer) Seek(offset int64, whence int) (int64, error) {
var newOffset int64
switch whence {
case io.SeekStart:
newOffset = offset
case io.SeekCurrent:
newOffset = fb.offset + offset
case io.SeekEnd:
newOffset = int64(len(fb.buffer)) + offset
default:
return 0, errors.New("Unknown Seek Method")
}
if newOffset > int64(len(fb.buffer)) || newOffset < 0 {
return 0, fmt.Errorf("Invalid Offset %d", offset)
}
fb.offset = newOffset
return newOffset, nil
}

Query WMI from Go

I would like to run WMI queries from Go. There are ways to call DLL functions from Go. My understanding is that there must be some DLL somewhere which, with the correct call, will return some data I can parse and use. I'd prefer to avoid calling into C or C++, especially since I would guess those are wrappers over the Windows API itself.
I've examined the output of dumpbin.exe /exports c:\windows\system32\wmi.dll, and the following entry looks promising:
WmiQueryAllDataA (forwarded to wmiclnt.WmiQueryAllDataA)
However I'm not sure what to do from here. What arguments does this function take? What does it return? Searching for WmiQueryAllDataA is not helpful. And that name only appears in a comment of c:\program files (x86)\windows kits\8.1\include\shared\wmistr.h, but with no function signature.
Are there better methods? Is there another DLL? Am I missing something? Should I just use a C wrapper?
Running a WMI query in Linqpad with .NET Reflector shows the use of WmiNetUtilsHelper:ExecQueryWmi (and a _f version), but neither have a viewable implementation.
Update: use the github.com/StackExchange/wmi package which uses the solution in the accepted answer.
Welcome to the wonderful world of COM, Object Oriented Programming in C from when C++ was "a young upstart".
On github mattn has thrown together a little wrapper in Go, which I used to throw together a quick example program. "This repository was created for experimentation and should be considered unstable." instills all sorts of confidence.
I'm leaving out a lot of error checking. Trust me when I say, you'll want to add it back.
package main
import (
"github.com/mattn/go-ole"
"github.com/mattn/go-ole/oleutil"
)
func main() {
// init COM, oh yeah
ole.CoInitialize(0)
defer ole.CoUninitialize()
unknown, _ := oleutil.CreateObject("WbemScripting.SWbemLocator")
defer unknown.Release()
wmi, _ := unknown.QueryInterface(ole.IID_IDispatch)
defer wmi.Release()
// service is a SWbemServices
serviceRaw, _ := oleutil.CallMethod(wmi, "ConnectServer")
service := serviceRaw.ToIDispatch()
defer service.Release()
// result is a SWBemObjectSet
resultRaw, _ := oleutil.CallMethod(service, "ExecQuery", "SELECT * FROM Win32_Process")
result := resultRaw.ToIDispatch()
defer result.Release()
countVar, _ := oleutil.GetProperty(result, "Count")
count := int(countVar.Val)
for i := 0; i < count; i++ {
// item is a SWbemObject, but really a Win32_Process
itemRaw, _ := oleutil.CallMethod(result, "ItemIndex", i)
item := itemRaw.ToIDispatch()
defer item.Release()
asString, _ := oleutil.GetProperty(item, "Name")
println(asString.ToString())
}
}
The real meat is the call to ExecQuery; I happened to grab Win32_Process from the available classes because it's easy to understand and print.
On my machine, this prints:
System Idle Process
System
smss.exe
csrss.exe
wininit.exe
services.exe
lsass.exe
svchost.exe
svchost.exe
atiesrxx.exe
svchost.exe
svchost.exe
svchost.exe
svchost.exe
svchost.exe
spoolsv.exe
svchost.exe
AppleOSSMgr.exe
AppleTimeSrv.exe
... and so on
go.exe
main.exe
I'm not running it elevated or with UAC disabled, but some WMI providers are gonna require a privileged user.
I'm also not 100% sure that this won't leak a little; you'll want to dig into that. COM objects are reference counted, so defer should be a pretty good fit there (provided the method isn't crazy long running), but go-ole may have some magic inside I didn't notice.
I'm commenting over a year later, but there is a solution here on github (and posted below for posterity).
// +build windows
/*
Package wmi provides a WQL interface for WMI on Windows.
Example code to print names of running processes:
type Win32_Process struct {
Name string
}
func main() {
var dst []Win32_Process
q := wmi.CreateQuery(&dst, "")
err := wmi.Query(q, &dst)
if err != nil {
log.Fatal(err)
}
for i, v := range dst {
println(i, v.Name)
}
}
*/
package wmi
import (
"bytes"
"errors"
"fmt"
"log"
"os"
"reflect"
"runtime"
"strconv"
"strings"
"sync"
"time"
"github.com/mattn/go-ole"
"github.com/mattn/go-ole/oleutil"
)
var l = log.New(os.Stdout, "", log.LstdFlags)
var (
ErrInvalidEntityType = errors.New("wmi: invalid entity type")
lock sync.Mutex
)
// QueryNamespace invokes Query with the given namespace on the local machine.
func QueryNamespace(query string, dst interface{}, namespace string) error {
return Query(query, dst, nil, namespace)
}
// Query runs the WQL query and appends the values to dst.
//
// dst must have type *[]S or *[]*S, for some struct type S. Fields selected in
// the query must have the same name in dst. Supported types are all signed and
// unsigned integers, time.Time, string, bool, or a pointer to one of those.
// Array types are not supported.
//
// By default, the local machine and default namespace are used. These can be
// changed using connectServerArgs. See
// http://msdn.microsoft.com/en-us/library/aa393720.aspx for details.
func Query(query string, dst interface{}, connectServerArgs ...interface{}) error {
dv := reflect.ValueOf(dst)
if dv.Kind() != reflect.Ptr || dv.IsNil() {
return ErrInvalidEntityType
}
dv = dv.Elem()
mat, elemType := checkMultiArg(dv)
if mat == multiArgTypeInvalid {
return ErrInvalidEntityType
}
lock.Lock()
defer lock.Unlock()
runtime.LockOSThread()
defer runtime.UnlockOSThread()
err := ole.CoInitializeEx(0, ole.COINIT_MULTITHREADED)
if err != nil {
oleerr := err.(*ole.OleError)
// S_FALSE = 0x00000001 // CoInitializeEx was already called on this thread
if oleerr.Code() != ole.S_OK && oleerr.Code() != 0x00000001 {
return err
}
} else {
// Only invoke CoUninitialize if the thread was not initialized before.
// This will allow other go packages based on go-ole play along
// with this library.
defer ole.CoUninitialize()
}
unknown, err := oleutil.CreateObject("WbemScripting.SWbemLocator")
if err != nil {
return err
}
defer unknown.Release()
wmi, err := unknown.QueryInterface(ole.IID_IDispatch)
if err != nil {
return err
}
defer wmi.Release()
// service is a SWbemServices
serviceRaw, err := oleutil.CallMethod(wmi, "ConnectServer", connectServerArgs...)
if err != nil {
return err
}
service := serviceRaw.ToIDispatch()
defer serviceRaw.Clear()
// result is a SWBemObjectSet
resultRaw, err := oleutil.CallMethod(service, "ExecQuery", query)
if err != nil {
return err
}
result := resultRaw.ToIDispatch()
defer resultRaw.Clear()
count, err := oleInt64(result, "Count")
if err != nil {
return err
}
// Initialize a slice with Count capacity
dv.Set(reflect.MakeSlice(dv.Type(), 0, int(count)))
var errFieldMismatch error
for i := int64(0); i < count; i++ {
err := func() error {
// item is a SWbemObject, but really a Win32_Process
itemRaw, err := oleutil.CallMethod(result, "ItemIndex", i)
if err != nil {
return err
}
item := itemRaw.ToIDispatch()
defer itemRaw.Clear()
ev := reflect.New(elemType)
if err = loadEntity(ev.Interface(), item); err != nil {
if _, ok := err.(*ErrFieldMismatch); ok {
// We continue loading entities even in the face of field mismatch errors.
// If we encounter any other error, that other error is returned. Otherwise,
// an ErrFieldMismatch is returned.
errFieldMismatch = err
} else {
return err
}
}
if mat != multiArgTypeStructPtr {
ev = ev.Elem()
}
dv.Set(reflect.Append(dv, ev))
return nil
}()
if err != nil {
return err
}
}
return errFieldMismatch
}
// ErrFieldMismatch is returned when a field is to be loaded into a different
// type than the one it was stored from, or when a field is missing or
// unexported in the destination struct.
// StructType is the type of the struct pointed to by the destination argument.
type ErrFieldMismatch struct {
StructType reflect.Type
FieldName string
Reason string
}
func (e *ErrFieldMismatch) Error() string {
return fmt.Sprintf("wmi: cannot load field %q into a %q: %s",
e.FieldName, e.StructType, e.Reason)
}
var timeType = reflect.TypeOf(time.Time{})
// loadEntity loads a SWbemObject into a struct pointer.
func loadEntity(dst interface{}, src *ole.IDispatch) (errFieldMismatch error) {
v := reflect.ValueOf(dst).Elem()
for i := 0; i < v.NumField(); i++ {
f := v.Field(i)
isPtr := f.Kind() == reflect.Ptr
if isPtr {
ptr := reflect.New(f.Type().Elem())
f.Set(ptr)
f = f.Elem()
}
n := v.Type().Field(i).Name
if !f.CanSet() {
return &ErrFieldMismatch{
StructType: f.Type(),
FieldName: n,
Reason: "CanSet() is false",
}
}
prop, err := oleutil.GetProperty(src, n)
if err != nil {
errFieldMismatch = &ErrFieldMismatch{
StructType: f.Type(),
FieldName: n,
Reason: "no such struct field",
}
continue
}
defer prop.Clear()
switch val := prop.Value().(type) {
case int, int64:
var v int64
switch val := val.(type) {
case int:
v = int64(val)
case int64:
v = val
default:
panic("unexpected type")
}
switch f.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
f.SetInt(v)
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
f.SetUint(uint64(v))
default:
return &ErrFieldMismatch{
StructType: f.Type(),
FieldName: n,
Reason: "not an integer class",
}
}
case string:
iv, err := strconv.ParseInt(val, 10, 64)
switch f.Kind() {
case reflect.String:
f.SetString(val)
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
if err != nil {
return err
}
f.SetInt(iv)
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
if err != nil {
return err
}
f.SetUint(uint64(iv))
case reflect.Struct:
switch f.Type() {
case timeType:
if len(val) == 25 {
mins, err := strconv.Atoi(val[22:])
if err != nil {
return err
}
val = val[:22] + fmt.Sprintf("%02d%02d", mins/60, mins%60)
}
t, err := time.Parse("20060102150405.000000-0700", val)
if err != nil {
return err
}
f.Set(reflect.ValueOf(t))
}
}
case bool:
switch f.Kind() {
case reflect.Bool:
f.SetBool(val)
default:
return &ErrFieldMismatch{
StructType: f.Type(),
FieldName: n,
Reason: "not a bool",
}
}
default:
typeof := reflect.TypeOf(val)
if isPtr && typeof == nil {
break
}
return &ErrFieldMismatch{
StructType: f.Type(),
FieldName: n,
Reason: fmt.Sprintf("unsupported type (%T)", val),
}
}
}
return errFieldMismatch
}
type multiArgType int
const (
multiArgTypeInvalid multiArgType = iota
multiArgTypeStruct
multiArgTypeStructPtr
)
// checkMultiArg checks that v has type []S, []*S for some struct type S.
//
// It returns what category the slice's elements are, and the reflect.Type
// that represents S.
func checkMultiArg(v reflect.Value) (m multiArgType, elemType reflect.Type) {
if v.Kind() != reflect.Slice {
return multiArgTypeInvalid, nil
}
elemType = v.Type().Elem()
switch elemType.Kind() {
case reflect.Struct:
return multiArgTypeStruct, elemType
case reflect.Ptr:
elemType = elemType.Elem()
if elemType.Kind() == reflect.Struct {
return multiArgTypeStructPtr, elemType
}
}
return multiArgTypeInvalid, nil
}
func oleInt64(item *ole.IDispatch, prop string) (int64, error) {
v, err := oleutil.GetProperty(item, prop)
if err != nil {
return 0, err
}
defer v.Clear()
i := int64(v.Val)
return i, nil
}
// CreateQuery returns a WQL query string that queries all columns of src. where
// is an optional string that is appended to the query, to be used with WHERE
// clauses. In such a case, the "WHERE" string should appear at the beginning.
func CreateQuery(src interface{}, where string) string {
var b bytes.Buffer
b.WriteString("SELECT ")
s := reflect.Indirect(reflect.ValueOf(src))
t := s.Type()
if s.Kind() == reflect.Slice {
t = t.Elem()
}
if t.Kind() != reflect.Struct {
return ""
}
var fields []string
for i := 0; i < t.NumField(); i++ {
fields = append(fields, t.Field(i).Name)
}
b.WriteString(strings.Join(fields, ", "))
b.WriteString(" FROM ")
b.WriteString(t.Name())
b.WriteString(" " + where)
return b.String()
}
To access the winmgmts object or a namespace (which is the same), you can use the code below. Basically, you need to specify the namespace as parameter, which is not documented properly in go-ole.
In the code below, you can also see how to access a class within this namespace and execute a method.
package main
import (
"log"
"github.com/go-ole/go-ole"
"github.com/go-ole/go-ole/oleutil"
)
func main() {
ole.CoInitializeEx(0, ole.COINIT_MULTITHREADED)
defer ole.CoUninitialize()
unknown, err := oleutil.CreateObject("WbemScripting.SWbemLocator")
if err != nil {
log.Panic(err)
}
defer unknown.Release()
wmi, err := unknown.QueryInterface(ole.IID_IDispatch)
if err != nil {
log.Panic(err)
}
defer wmi.Release()
// Connect to namespace
// root/PanasonicPC = winmgmts:\\.\root\PanasonicPC
serviceRaw, err := oleutil.CallMethod(wmi, "ConnectServer", nil, "root/PanasonicPC")
if err != nil {
log.Panic(err)
}
service := serviceRaw.ToIDispatch()
defer serviceRaw.Clear()
// Get class
setBiosRaw, err := oleutil.CallMethod(service, "Get", "SetBIOS4Conf")
if err != nil {
log.Panic(err)
}
setBios := setBiosRaw.ToIDispatch()
defer setBiosRaw.Clear()
// Run method
resultRaw, err := oleutil.CallMethod(setBios, "AccessAuthorization", "letmein")
if err != nil {
log.Panic(err)
}
resultVal := resultRaw.Value().(int32)
log.Println("Return Code:", resultVal)
}
import (
	"fmt"
	"os/exec"
)

func (lcu *LCU) GrabToken() {
	cmd := exec.Command("powershell", "$cmdline = Get-WmiObject -Class Win32_Process")

	out, err := cmd.CombinedOutput()
	if err != nil {
		fmt.Println(err)
	}
	outstr := string(out)
	_ = outstr // parse the command output here
}