How to determine the name of the database driver I'm using? - sql

In code that tries to be database-agnostic, I would like to run some database-specific queries, so I need to know the name of the database driver in Go:
db,err := sql.Open(dbstr, dbconnstr)
if err != nil {
log.Fatal(err)
}
errp := db.Ping()
if errp != nil {
log.Fatal(errp)
}
log.Printf("%s\n", db.Driver())
How can I determine the name of the database driver I'm using?

Give your database string in URL format, like postgres://postgres@localhost:5432/db_name?sslmode=disable.
Then use the Parse function of the net/url package to find the database type you are using (it is the URL scheme). Based on the database type, run the database-specific queries.
func New(url string) (Driver, error) {
u, err := neturl.Parse(url)
if err != nil {
return nil, err
}
switch u.Scheme {
case "postgres":
d := &postgres.Driver{}
if err := d.Initialize(url); err != nil {
return nil, err
}
return d, nil
case "mysql":
d := &mysql.Driver{}
if err := d.Initialize(url); err != nil {
return nil, err
}
return d, nil
case "bash":
d := &bash.Driver{}
if err := d.Initialize(url); err != nil {
return nil, err
}
return d, nil
case "cassandra":
d := &cassandra.Driver{}
if err := d.Initialize(url); err != nil {
return nil, err
}
return d, nil
case "sqlite3":
d := &sqlite3.Driver{}
if err := d.Initialize(url); err != nil {
return nil, err
}
return d, nil
default:
return nil, fmt.Errorf("Driver '%s' not found.", u.Scheme)
}
}
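The core of that approach is just reading the scheme. Here is a minimal, self-contained sketch of only that step, using the standard net/url package; dialectFromURL and the set of accepted schemes are hypothetical choices for illustration, not part of any library.

```go
package main

import (
	"fmt"
	"net/url"
)

// dialectFromURL returns the URL scheme (e.g. "postgres" or "mysql"), which
// this sketch treats as the database type. The accepted schemes are an
// illustrative assumption; extend the list to match the drivers you use.
func dialectFromURL(raw string) (string, error) {
	u, err := url.Parse(raw) // url.Parse also lower-cases the scheme
	if err != nil {
		return "", err
	}
	switch u.Scheme {
	case "postgres", "mysql", "sqlite3":
		return u.Scheme, nil
	default:
		return "", fmt.Errorf("unsupported database scheme %q", u.Scheme)
	}
}

func main() {
	d, err := dialectFromURL("postgres://postgres@localhost:5432/db_name?sslmode=disable")
	fmt.Println(d, err) // postgres <nil>
}
```

Keep in mind that the scheme only tells you which dialect to use. Not every driver accepts a URL as its DSN: github.com/lib/pq does, but github.com/go-sql-driver/mysql expects its own user:pass@tcp(host)/dbname format, so you may need to convert the URL before calling sql.Open.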

You should already know the name of the database driver, because it is the value you passed to sql.Open as the dbstr variable.
db, err := sql.Open("postgres", "user= ... ")
if err != nil {
log.Fatal(err)
}
db.Driver() correctly returns the underlying driver in use, but you are formatting it as a string (because of %s). If you replace %s with %T, you will see that it correctly prints the type:
log.Printf("%T\n", db.Driver())
For example, if you use github.com/lib/pq, the output is *pq.drv. This is the same as using the reflect package:
log.Printf("%s\n", reflect.TypeOf(db.Driver()))
It may be impractical to use that value for conditional execution. Moreover, the driver.Driver interface doesn't provide any way to get driver-specific information; its only method is Open().
If you have specific needs, you may want to either use the driver name passed when you open the connection, or create specific drivers that delegate to the original ones and handle your custom logic.
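A minimal sketch of the first option, keeping the driver name next to the *sql.DB. NamedDB, OpenNamed and Rebind are hypothetical names, and the placeholder rewrite is only one example of dialect-specific logic; it is naive and does not skip `?` characters inside string literals.

```go
package main

import (
	"database/sql"
	"fmt"
	"strings"
)

// NamedDB is a hypothetical wrapper that remembers which driver name was
// passed to sql.Open. Embedding *sql.DB keeps Query, Exec, etc. available.
type NamedDB struct {
	*sql.DB
	DriverName string
}

// OpenNamed opens the database and records the driver name alongside it.
func OpenNamed(driverName, dsn string) (*NamedDB, error) {
	db, err := sql.Open(driverName, dsn)
	if err != nil {
		return nil, err // e.g. `sql: unknown driver` if it was never imported
	}
	return &NamedDB{DB: db, DriverName: driverName}, nil
}

// Rebind rewrites '?' placeholders into $1, $2, ... for PostgreSQL drivers
// ("postgres" for lib/pq, "pgx" for pgx's stdlib adapter) and leaves the
// query unchanged otherwise. Naive: '?' inside string literals is rewritten too.
func Rebind(driverName, query string) string {
	switch driverName {
	case "postgres", "pgx":
		var b strings.Builder
		n := 0
		for _, r := range query {
			if r == '?' {
				n++
				fmt.Fprintf(&b, "$%d", n)
				continue
			}
			b.WriteRune(r)
		}
		return b.String()
	default:
		return query
	}
}

func main() {
	fmt.Println(Rebind("postgres", "SELECT name FROM servers WHERE id = ? AND zone = ?"))
	if _, err := OpenNamed("no-such-driver", ""); err != nil {
		fmt.Println("open failed:", err)
	}
}
```

The same idea is available ready-made in github.com/jmoiron/sqlx, whose DB type exposes DriverName() and Rebind().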

Related

How to intercept `rollback` in gorm?

I need to execute some logic after all create operations fail.
Callbacks seem like they could handle this, but when the operation runs inside a transaction, it may not actually take effect (it can still be rolled back later). I need to do this after the rollback and handle it accordingly. So the question is: how do I intercept a rollback?
You can use a manual transaction in a function like this:
func CreateAnimals(db *gorm.DB) error {
// Note the use of tx as the database handle once you are within a transaction
tx := db.Begin()
defer func() {
if r := recover(); r != nil {
tx.Rollback()
panic(r) // re-panic after rolling back so the failure isn't silently swallowed
}
}()
if err := tx.Error; err != nil {
return err
}
if err := tx.Create(&Animal{Name: "Giraffe"}).Error; err != nil {
tx.Rollback()
return err
}
if err := tx.Create(&Animal{Name: "Lion"}).Error; err != nil {
tx.Rollback()
return err
}
return tx.Commit().Error
}
If CreateAnimals returns an error, nothing was committed, so the caller can do the desired follow-up work at that point.

SQL Next not advancing cursor

I have a function that I used to iterate over a result set from a query:
func readRows(rows *sql.Rows, translator func(*sql.Rows) error) error {
defer rows.Close()
// Iterate over each row in the rows and scan each; if an error occurs then return
for shouldScan := rows.Next(); shouldScan; {
if err := translator(rows); err != nil {
return err
}
}
// Check if the rows had an error; if they did then return them. Otherwise,
// close the rows and return an error if the close function fails
if err := rows.Err(); err != nil {
return err
}
return nil
}
The translator function is primarily responsible for calling Scan on the *sql.Rows object. An example of this is:
readRows(rows, func(scanner *sql.Rows) error {
var entry gopb.TestObject
// Embed the variables into a list that we can use to pull information out of the rows
scanned := []interface{}{...}
if err := scanner.Scan(scanned...); err != nil {
return err
}
entries = append(entries, &entry)
return nil
})
I wrote a unit test for this code:
// Create the SQL mock and the RDS requester
db, mock, _ := sqlmock.New()
requester := Requester{conn: db}
defer db.Close()
// Create the rows we'll use for testing the query
rows := sqlmock.NewRows([]string{"id", "data"}).
AddRow(0, "data")
// Verify the command order for the transaction
mock.ExpectBegin()
mock.ExpectQuery(regexp.QuoteMeta("SELECT `id`, `data`, FROM `data`")).WillReturnRows(rows)
mock.ExpectRollback()
// Attempt to get the data
data, err := requester.GetData(context.TODO())
However, it appears that Next is being called infinitely. I'm not sure if this is an sqlmock issue or an issue with my code. Any help would be appreciated.

How can I tell the PATCH method which fields I want to update

I'm working on a simple REST API and I'm having trouble with the PATCH method. I don't know how to tell the handler and the query which fields I want to update in the database (for example, which fields were passed in the JSON). Here is what I have so far:
func PatchServer(c echo.Context) error {
patchedServer := new(structs.Server)
requestID := c.Param("id")
if err := c.Bind(patchedServer); err != nil {
return err
}
sql := "UPDATE servers SET server_name = CASE WHEN ? IS NOT NULL THEN ? END WHERE id = ?"
stmt, err := db.Get().Prepare(sql)
if err != nil {
panic(err)
}
_, err2 := stmt.Exec(patchedServer.Name, patchedServer.Name, requestID)
if err2 != nil {
panic(err2)
}
fmt.Println(patchedServer.ID, patchedServer.Name, patchedServer.Components)
fmt.Println("Requested id: ", requestID)
return c.JSON(http.StatusOK, "Patched!")
}

Bulk insert to copy a SQL table with Go

For context, I'm new to Go and I'm creating a program that can copy tables from Oracle to MySQL.
I use the database/sql Go package, so I assume it can be used for migrating any kind of database.
To simplify my question, I'm copying within the same MySQL database, from table world.city to world.city_copy2.
With the following code, I end up with the last row's values repeated in every row of the table :-(
Do I somehow need to read all the values inside the loop? What is the efficient way to do that?
package main
import (
"database/sql"
"fmt"
"strings"
_ "github.com/go-sql-driver/mysql"
)
const (
user = "user"
pass = "testPass"
server = "localhost"
)
func main() {
fmt.Print("test")
conStr := fmt.Sprintf("%s:%s@tcp(%s)/world", user, pass, server)
db, err := sql.Open("mysql", conStr)
if err != nil {
panic(err.Error())
}
defer db.Close()
err = db.Ping()
if err != nil {
panic(err.Error())
}
rows, err := db.Query("SELECT * FROM city")
if err != nil {
panic(err.Error()) // proper error handling instead of panic in your app
}
columns, err := rows.Columns()
if err != nil {
panic(err.Error()) // proper error handling instead of panic in your app
}
// Make a slice for the values
values := make([]sql.RawBytes, len(columns))
// rows.Scan wants '[]interface{}' as an argument, so we must copy the
// references into such a slice
scanArgs := make([]interface{}, len(values))
for i := range values {
scanArgs[i] = &values[i]
}
// that string will be generated according to len of columns
placeHolders := "( ?, ?, ?, ?, ? )"
// slice will contain all the values at the end
bulkValues := []interface{}{}
valueStrings := make([]string, 0)
for rows.Next() {
// get RawBytes from data
err = rows.Scan(scanArgs...)
if err != nil {
panic(err.Error()) // proper error handling instead of panic in your app
}
valueStrings = append(valueStrings, placeHolders)
bulkValues = append(bulkValues, scanArgs...)
//
}
stmStr := fmt.Sprintf("INSERT INTO city_copy2 VALUES %s", strings.Join(valueStrings, ","))
_, err = db.Exec(stmStr, bulkValues...)
if err != nil {
panic(err.Error())
}
}
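I have checked the docs of the library, and the problem here is that bulkValues stores the pointers held in scanArgs (&values[i]), not the values themselves. Every row appends the same pointers, so by the time db.Exec runs they all point into values, which holds the last row scanned. (A sql.RawBytes is also only valid until the next call to Next, Scan, or Close.)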
You need to copy the data out of the values variable on each iteration, like below:
func main() {
fmt.Print("test")
conStr := fmt.Sprintf("%s:%s@tcp(%s)/soverflow", user, pass, server)
db, err := sql.Open("mysql", conStr)
if err != nil {
panic(err.Error())
}
defer db.Close()
err = db.Ping()
if err != nil {
panic(err.Error())
}
rows, err := db.Query("SELECT * FROM city")
if err != nil {
panic(err.Error()) // proper error handling instead of panic in your app
}
columns, err := rows.Columns()
if err != nil {
panic(err.Error()) // proper error handling instead of panic in your app
}
// Make a slice for the values
values := make([]sql.RawBytes, len(columns))
// rows.Scan wants '[]interface{}' as an argument, so we must copy the
// references into such a slice
scanArgs := make([]interface{}, len(values))
for i := range values {
scanArgs[i] = &values[i]
}
// that string will be generated according to len of columns
placeHolders := "( ?, ?, ?, ?, ? )"
// slice will contain all the values at the end
bulkValues := []interface{}{}
valueStrings := make([]string, 0)
// make an interface to keep the record's value
record := make([]interface{}, len(columns))
for rows.Next() {
// get RawBytes from data
err = rows.Scan(scanArgs...)
if err != nil {
panic(err.Error()) // proper error handling instead of panic in your app
}
valueStrings = append(valueStrings, placeHolders)
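for i, col := range values {
// Be careful with data types here; check the driver docs for details.
// A NULL column scans as a nil RawBytes, so keep it nil (SQL NULL)
// instead of turning it into an empty string.
if col == nil {
record[i] = nil
} else {
record[i] = string(col)
}
}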
bulkValues = append(bulkValues, record...)
}
stmStr := fmt.Sprintf("INSERT INTO city_copy2 VALUES %s", strings.Join(valueStrings, ","))
_, err = db.Exec(stmStr, bulkValues...)
if err != nil {
panic(err.Error())
}
}
You can also find this example in the documentation.
Note: There might be more efficient ways to copy a database from Oracle to MySQL, but this answer only gives a quick solution for this particular issue.
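Two follow-ups the code above leaves open are the hard-coded "( ?, ?, ?, ?, ? )" string and the size of a single INSERT for a large table: MySQL rejects a prepared statement with more than 65535 placeholders. A sketch under those assumptions; rowPlaceholder and batchInserts are hypothetical helpers, and the table name is interpolated directly, so it must come from trusted code, not user input.

```go
package main

import (
	"fmt"
	"strings"
)

// rowPlaceholder builds "(?, ?, ..., ?)" for n columns.
func rowPlaceholder(n int) string {
	return "(" + strings.TrimSuffix(strings.Repeat("?, ", n), ", ") + ")"
}

// batchInserts splits rows into INSERT statements that each use at most
// maxParams placeholders (e.g. 65535 for MySQL prepared statements).
func batchInserts(table string, rows [][]interface{}, maxParams int) ([]string, [][]interface{}, error) {
	if len(rows) == 0 {
		return nil, nil, nil
	}
	cols := len(rows[0])
	if cols == 0 || cols > maxParams {
		return nil, nil, fmt.Errorf("cannot fit %d columns into %d placeholders", cols, maxParams)
	}
	perBatch := maxParams / cols
	ph := rowPlaceholder(cols)
	var stmts []string
	var argSets [][]interface{}
	for start := 0; start < len(rows); start += perBatch {
		end := start + perBatch
		if end > len(rows) {
			end = len(rows)
		}
		valueStrings := make([]string, 0, end-start)
		args := make([]interface{}, 0, (end-start)*cols)
		for _, r := range rows[start:end] {
			if len(r) != cols {
				return nil, nil, fmt.Errorf("row has %d columns, want %d", len(r), cols)
			}
			valueStrings = append(valueStrings, ph)
			args = append(args, r...)
		}
		stmts = append(stmts, fmt.Sprintf("INSERT INTO %s VALUES %s", table, strings.Join(valueStrings, ", ")))
		argSets = append(argSets, args)
	}
	return stmts, argSets, nil
}

func main() {
	rows := [][]interface{}{{1, "a"}, {2, "b"}, {3, "c"}}
	stmts, args, err := batchInserts("city_copy2", rows, 4)
	if err != nil {
		panic(err)
	}
	for i := range stmts {
		fmt.Println(stmts[i], args[i])
	}
}
```

Each statement would then be run with db.Exec(stmts[i], args[i]...), ideally inside one transaction so a failed batch can be rolled back.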

Query WMI from Go

I would like to run WMI queries from Go. There are ways to call DLL functions from Go. My understanding is that there must be some DLL somewhere which, with the correct call, will return some data I can parse and use. I'd prefer to avoid calling into C or C++, especially since I would guess those are wrappers over the Windows API itself.
I've examined the output of dumpbin.exe /exports c:\windows\system32\wmi.dll, and the following entry looks promising:
WmiQueryAllDataA (forwarded to wmiclnt.WmiQueryAllDataA)
However I'm not sure what to do from here. What arguments does this function take? What does it return? Searching for WmiQueryAllDataA is not helpful. And that name only appears in a comment of c:\program files (x86)\windows kits\8.1\include\shared\wmistr.h, but with no function signature.
Are there better methods? Is there another DLL? Am I missing something? Should I just use a C wrapper?
Running a WMI query in LINQPad with .NET Reflector shows the use of WmiNetUtilsHelper:ExecQueryWmi (and a _f version), but neither has a viewable implementation.
Update: use the github.com/StackExchange/wmi package which uses the solution in the accepted answer.
Welcome to the wonderful world of COM, Object Oriented Programming in C from when C++ was "a young upstart".
On GitHub, mattn has thrown together a little wrapper in Go, which I used to put together a quick example program. "This repository was created for experimentation and should be considered unstable." instills all sorts of confidence.
I'm leaving out a lot of error checking. Trust me when I say, you'll want to add it back.
package main
import (
"github.com/go-ole/go-ole"
"github.com/go-ole/go-ole/oleutil"
)
func main() {
// init COM, oh yeah
ole.CoInitialize(0)
defer ole.CoUninitialize()
unknown, _ := oleutil.CreateObject("WbemScripting.SWbemLocator")
defer unknown.Release()
wmi, _ := unknown.QueryInterface(ole.IID_IDispatch)
defer wmi.Release()
// service is a SWbemServices
serviceRaw, _ := oleutil.CallMethod(wmi, "ConnectServer")
service := serviceRaw.ToIDispatch()
defer service.Release()
// result is a SWBemObjectSet
resultRaw, _ := oleutil.CallMethod(service, "ExecQuery", "SELECT * FROM Win32_Process")
result := resultRaw.ToIDispatch()
defer result.Release()
countVar, _ := oleutil.GetProperty(result, "Count")
count := int(countVar.Val)
for i := 0; i < count; i++ {
// item is a SWbemObject, but really a Win32_Process
itemRaw, _ := oleutil.CallMethod(result, "ItemIndex", i)
item := itemRaw.ToIDispatch()
defer item.Release()
asString, _ := oleutil.GetProperty(item, "Name")
println(asString.ToString())
}
}
The real meat is the call to ExecQuery; I happened to pick Win32_Process from the available classes because it's easy to understand and print.
On my machine, this prints:
System Idle Process
System
smss.exe
csrss.exe
wininit.exe
services.exe
lsass.exe
svchost.exe
svchost.exe
atiesrxx.exe
svchost.exe
svchost.exe
svchost.exe
svchost.exe
svchost.exe
spoolsv.exe
svchost.exe
AppleOSSMgr.exe
AppleTimeSrv.exe
... and so on
go.exe
main.exe
I'm not running it elevated or with UAC disabled, but some WMI providers are gonna require a privileged user.
I'm also not 100% sure that this won't leak a little; you'll want to dig into that. COM objects are reference counted, so defer should be a pretty good fit there (provided the method isn't crazy long running), but go-ole may have some magic inside that I didn't notice.
I'm commenting over a year later, but there is a solution on GitHub (posted below for posterity).
// +build windows
/*
Package wmi provides a WQL interface for WMI on Windows.
Example code to print names of running processes:
type Win32_Process struct {
Name string
}
func main() {
var dst []Win32_Process
q := wmi.CreateQuery(&dst, "")
err := wmi.Query(q, &dst)
if err != nil {
log.Fatal(err)
}
for i, v := range dst {
println(i, v.Name)
}
}
*/
package wmi
import (
"bytes"
"errors"
"fmt"
"log"
"os"
"reflect"
"runtime"
"strconv"
"strings"
"sync"
"time"
"github.com/go-ole/go-ole"
"github.com/go-ole/go-ole/oleutil"
)
var l = log.New(os.Stdout, "", log.LstdFlags)
var (
ErrInvalidEntityType = errors.New("wmi: invalid entity type")
lock sync.Mutex
)
// QueryNamespace invokes Query with the given namespace on the local machine.
func QueryNamespace(query string, dst interface{}, namespace string) error {
return Query(query, dst, nil, namespace)
}
// Query runs the WQL query and appends the values to dst.
//
// dst must have type *[]S or *[]*S, for some struct type S. Fields selected in
// the query must have the same name in dst. Supported types are all signed and
// unsigned integers, time.Time, string, bool, or a pointer to one of those.
// Array types are not supported.
//
// By default, the local machine and default namespace are used. These can be
// changed using connectServerArgs. See
// http://msdn.microsoft.com/en-us/library/aa393720.aspx for details.
func Query(query string, dst interface{}, connectServerArgs ...interface{}) error {
dv := reflect.ValueOf(dst)
if dv.Kind() != reflect.Ptr || dv.IsNil() {
return ErrInvalidEntityType
}
dv = dv.Elem()
mat, elemType := checkMultiArg(dv)
if mat == multiArgTypeInvalid {
return ErrInvalidEntityType
}
lock.Lock()
defer lock.Unlock()
runtime.LockOSThread()
defer runtime.UnlockOSThread()
err := ole.CoInitializeEx(0, ole.COINIT_MULTITHREADED)
if err != nil {
oleerr := err.(*ole.OleError)
// S_FALSE = 0x00000001 // CoInitializeEx was already called on this thread
if oleerr.Code() != ole.S_OK && oleerr.Code() != 0x00000001 {
return err
}
} else {
// Only invoke CoUninitialize if the thread was not initialized before.
// This will allow other Go packages based on go-ole to play along
// with this library.
defer ole.CoUninitialize()
}
unknown, err := oleutil.CreateObject("WbemScripting.SWbemLocator")
if err != nil {
return err
}
defer unknown.Release()
wmi, err := unknown.QueryInterface(ole.IID_IDispatch)
if err != nil {
return err
}
defer wmi.Release()
// service is a SWbemServices
serviceRaw, err := oleutil.CallMethod(wmi, "ConnectServer", connectServerArgs...)
if err != nil {
return err
}
service := serviceRaw.ToIDispatch()
defer serviceRaw.Clear()
// result is a SWBemObjectSet
resultRaw, err := oleutil.CallMethod(service, "ExecQuery", query)
if err != nil {
return err
}
result := resultRaw.ToIDispatch()
defer resultRaw.Clear()
count, err := oleInt64(result, "Count")
if err != nil {
return err
}
// Initialize a slice with Count capacity
dv.Set(reflect.MakeSlice(dv.Type(), 0, int(count)))
var errFieldMismatch error
for i := int64(0); i < count; i++ {
err := func() error {
// item is a SWbemObject, but really a Win32_Process
itemRaw, err := oleutil.CallMethod(result, "ItemIndex", i)
if err != nil {
return err
}
item := itemRaw.ToIDispatch()
defer itemRaw.Clear()
ev := reflect.New(elemType)
if err = loadEntity(ev.Interface(), item); err != nil {
if _, ok := err.(*ErrFieldMismatch); ok {
// We continue loading entities even in the face of field mismatch errors.
// If we encounter any other error, that other error is returned. Otherwise,
// an ErrFieldMismatch is returned.
errFieldMismatch = err
} else {
return err
}
}
if mat != multiArgTypeStructPtr {
ev = ev.Elem()
}
dv.Set(reflect.Append(dv, ev))
return nil
}()
if err != nil {
return err
}
}
return errFieldMismatch
}
// ErrFieldMismatch is returned when a field is to be loaded into a different
// type than the one it was stored from, or when a field is missing or
// unexported in the destination struct.
// StructType is the type of the struct pointed to by the destination argument.
type ErrFieldMismatch struct {
StructType reflect.Type
FieldName string
Reason string
}
func (e *ErrFieldMismatch) Error() string {
return fmt.Sprintf("wmi: cannot load field %q into a %q: %s",
e.FieldName, e.StructType, e.Reason)
}
var timeType = reflect.TypeOf(time.Time{})
// loadEntity loads a SWbemObject into a struct pointer.
func loadEntity(dst interface{}, src *ole.IDispatch) (errFieldMismatch error) {
v := reflect.ValueOf(dst).Elem()
for i := 0; i < v.NumField(); i++ {
f := v.Field(i)
isPtr := f.Kind() == reflect.Ptr
if isPtr {
ptr := reflect.New(f.Type().Elem())
f.Set(ptr)
f = f.Elem()
}
n := v.Type().Field(i).Name
if !f.CanSet() {
return &ErrFieldMismatch{
StructType: f.Type(),
FieldName: n,
Reason: "CanSet() is false",
}
}
prop, err := oleutil.GetProperty(src, n)
if err != nil {
errFieldMismatch = &ErrFieldMismatch{
StructType: f.Type(),
FieldName: n,
Reason: "no such struct field",
}
continue
}
defer prop.Clear()
switch val := prop.Value().(type) {
case int, int64:
var v int64
switch val := val.(type) {
case int:
v = int64(val)
case int64:
v = val
default:
panic("unexpected type")
}
switch f.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
f.SetInt(v)
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
f.SetUint(uint64(v))
default:
return &ErrFieldMismatch{
StructType: f.Type(),
FieldName: n,
Reason: "not an integer class",
}
}
case string:
iv, err := strconv.ParseInt(val, 10, 64)
switch f.Kind() {
case reflect.String:
f.SetString(val)
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
if err != nil {
return err
}
f.SetInt(iv)
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
if err != nil {
return err
}
f.SetUint(uint64(iv))
case reflect.Struct:
switch f.Type() {
case timeType:
if len(val) == 25 {
mins, err := strconv.Atoi(val[22:])
if err != nil {
return err
}
val = val[:22] + fmt.Sprintf("%02d%02d", mins/60, mins%60)
}
t, err := time.Parse("20060102150405.000000-0700", val)
if err != nil {
return err
}
f.Set(reflect.ValueOf(t))
}
}
case bool:
switch f.Kind() {
case reflect.Bool:
f.SetBool(val)
default:
return &ErrFieldMismatch{
StructType: f.Type(),
FieldName: n,
Reason: "not a bool",
}
}
default:
typeof := reflect.TypeOf(val)
if isPtr && typeof == nil {
break
}
return &ErrFieldMismatch{
StructType: f.Type(),
FieldName: n,
Reason: fmt.Sprintf("unsupported type (%T)", val),
}
}
}
return errFieldMismatch
}
type multiArgType int
const (
multiArgTypeInvalid multiArgType = iota
multiArgTypeStruct
multiArgTypeStructPtr
)
// checkMultiArg checks that v has type []S, []*S for some struct type S.
//
// It returns what category the slice's elements are, and the reflect.Type
// that represents S.
func checkMultiArg(v reflect.Value) (m multiArgType, elemType reflect.Type) {
if v.Kind() != reflect.Slice {
return multiArgTypeInvalid, nil
}
elemType = v.Type().Elem()
switch elemType.Kind() {
case reflect.Struct:
return multiArgTypeStruct, elemType
case reflect.Ptr:
elemType = elemType.Elem()
if elemType.Kind() == reflect.Struct {
return multiArgTypeStructPtr, elemType
}
}
return multiArgTypeInvalid, nil
}
func oleInt64(item *ole.IDispatch, prop string) (int64, error) {
v, err := oleutil.GetProperty(item, prop)
if err != nil {
return 0, err
}
defer v.Clear()
i := int64(v.Val)
return i, nil
}
// CreateQuery returns a WQL query string that queries all columns of src. where
// is an optional string that is appended to the query, to be used with WHERE
// clauses. In such a case, the "WHERE" string should appear at the beginning.
func CreateQuery(src interface{}, where string) string {
var b bytes.Buffer
b.WriteString("SELECT ")
s := reflect.Indirect(reflect.ValueOf(src))
t := s.Type()
if s.Kind() == reflect.Slice {
t = t.Elem()
}
if t.Kind() != reflect.Struct {
return ""
}
var fields []string
for i := 0; i < t.NumField(); i++ {
fields = append(fields, t.Field(i).Name)
}
b.WriteString(strings.Join(fields, ", "))
b.WriteString(" FROM ")
b.WriteString(t.Name())
b.WriteString(" " + where)
return b.String()
}
To access the winmgmts object or a namespace (which is the same thing), you can use the code below. Basically, you need to pass the namespace as the second argument to ConnectServer (the first argument, the server, can be nil for the local machine), which is not documented properly in go-ole.
In the code below, you can also see how to access a class within this namespace and execute a method.
package main
import (
"log"
"github.com/go-ole/go-ole"
"github.com/go-ole/go-ole/oleutil"
)
func main() {
ole.CoInitializeEx(0, ole.COINIT_MULTITHREADED)
defer ole.CoUninitialize()
unknown, err := oleutil.CreateObject("WbemScripting.SWbemLocator")
if err != nil {
log.Panic(err)
}
defer unknown.Release()
wmi, err := unknown.QueryInterface(ole.IID_IDispatch)
if err != nil {
log.Panic(err)
}
defer wmi.Release()
// Connect to namespace
// root/PanasonicPC = winmgmts:\\.\root\PanasonicPC
serviceRaw, err := oleutil.CallMethod(wmi, "ConnectServer", nil, "root/PanasonicPC")
if err != nil {
log.Panic(err)
}
service := serviceRaw.ToIDispatch()
defer serviceRaw.Clear()
// Get class
setBiosRaw, err := oleutil.CallMethod(service, "Get", "SetBIOS4Conf")
if err != nil {
log.Panic(err)
}
setBios := setBiosRaw.ToIDispatch()
defer setBiosRaw.Clear()
// Run method
resultRaw, err := oleutil.CallMethod(setBios, "AccessAuthorization", "letmein")
if err != nil {
log.Panic(err)
}
defer resultRaw.Clear()
resultVal := resultRaw.Value().(int32)
log.Println("Return Code:", resultVal)
}
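Alternatively, you can shell out to PowerShell and read its output:
import (
"os/exec"
)
// GrabToken runs a WMI query through PowerShell and returns its output.
// LCU is the author's own type and is not shown here.
func (lcu *LCU) GrabToken() (string, error) {
// The assignment alone produces no output, so echo $cmdline afterwards.
cmd := exec.Command("powershell", "$cmdline = Get-WmiObject -Class Win32_Process; $cmdline")
out, err := cmd.CombinedOutput()
if err != nil {
return "", err
}
return string(out), nil
}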