新手问题 golang的嵌套事务管理

Alber · 2018年06月18日 · 76 次阅读

golang 的事务管理是一件很麻烦的事,,能不能像 Java 那样,通过 Spring 管理事务,最近琢磨了一下,写了一个 demo,用来管理 golang 的事务,使其支持 golang 事务的嵌套调用。

其思想很简单,对于所有的写数据库操作,用一个标记来标记事务的开启和关闭 下面是一个演示示例:

我只是写了一个简单 demo,这里贴出实现代码:

package session

import (
    "database/sql"
)

const beginStatus = 1

// SessionFactory 会话工厂
type SessionFactory struct {
    *sql.DB
}

// Session 会话
type Session struct {
    db           *sql.DB // 原生db
    tx           *sql.Tx // 原生事务
    commitSign   int8    // 提交标记,控制是否提交事务
    rollbackSign bool    // 回滚标记,控制是否回滚事务
}

// NewSessionFactory 创建一个会话工厂
func NewSessionFactory(driverName, dataSourseName string) (*SessionFactory, error) {
    db, err := sql.Open(driverName, dataSourseName)
    if err != nil {
        panic(err)
    }
    factory := new(SessionFactory)
    factory.DB = db
    return factory, nil
}

// GetSession 获取一个Session
func (sf *SessionFactory) GetSession() *Session {
    session := new(Session)
    session.db = sf.DB
    return session
}

// Begin 开启事务
func (s *Session) Begin() error {
    s.rollbackSign = true
    if s.tx == nil {
        tx, err := s.db.Begin()
        if err != nil {
            return err
        }
        s.tx = tx
        s.commitSign = beginStatus
        return nil
    }
    s.commitSign++
    return nil
}

// Rollback 回滚事务
func (s *Session) Rollback() error {
    if s.tx != nil && s.rollbackSign == true {
        err := s.tx.Rollback()
        if err != nil {
            return err
        }
        s.tx = nil
        return nil
    }
    return nil
}

// Commit 提交事务
func (s *Session) Commit() error {
    s.rollbackSign = false
    if s.tx != nil {
        if s.commitSign == beginStatus {
            err := s.tx.Commit()
            if err != nil {
                return err
            }
            s.tx = nil
            return nil
        } else {
            s.commitSign--
        }
        return nil
    }
    return nil
}

// Exec 执行sql语句,如果已经开启事务,就以事务方式执行,如果没有开启事务,就以非事务方式执行
func (s *Session) Exec(query string, args ...interface{}) (sql.Result, error) {
    if s.tx != nil {
        return s.tx.Exec(query, args...)
    }
    return s.db.Exec(query, args...)
}

// QueryRow 如果已经开启事务,就以事务方式执行,如果没有开启事务,就以非事务方式执行
func (s *Session) QueryRow(query string, args ...interface{}) *sql.Row {
    if s.tx != nil {
        return s.tx.QueryRow(query, args...)
    }
    return s.db.QueryRow(query, args...)
}

// Query 查询数据,如果已经开启事务,就以事务方式执行,如果没有开启事务,就以非事务方式执行
func (s *Session) Query(query string, args ...interface{}) (*sql.Rows, error) {
    if s.tx != nil {
        return s.tx.Query(query, args...)
    }
    return s.db.Query(query, args...)
}

// Prepare 预执行,如果已经开启事务,就以事务方式执行,如果没有开启事务,就以非事务方式执行
func (s *Session) Prepare(query string) (*sql.Stmt, error) {
    if s.tx != nil {
        return s.tx.Prepare(query)
    }
    return s.db.Prepare(query)
}

测试用例:

package session

import (
    _ "github.com/go-sql-driver/mysql"
    "testing"
    "fmt"
)

var sf *SessionFactory

func init() {
    var err error
    sf, err = NewSessionFactory("mysql", "root:Liu123456@tcp(localhost:3306)/test?charset=utf8")
    if err != nil {
        fmt.Println(err)
        panic(err)
    }
}

type User struct {
    mobile string
    name   string
    age    int
    sex    int
}

type UserService struct {
    session *Session
}

func NewUserService() *UserService {
    return &UserService{sf.GetSession()}
}

func (s *UserService) Insert(user User) error {
    _, err := s.session.Exec("insert into user(mobile,name,age,sex) values(?,?,?,?)",
        user.mobile, user.name, user.age, user.sex)
    return err
}

func (s *UserService) AddInTx(user1, user2 User) error {
    err := s.session.Begin()
    if err != nil {
        return err
    }
    defer s.session.Rollback()

    err = s.Insert(user1)
    if err != nil {
        fmt.Println(err)
        return err
    }

    // return errors.New("err") 回滚测试

    err = s.Insert(user2)
    if err != nil {
        fmt.Println(err)
        return err
    }

    s.session.Commit()
    return nil
}

// DoNestingTx 嵌套事务
func (s *UserService) DoNestingTx() {
    err := s.session.Begin()
    if err != nil {
        fmt.Println(err)
    }
    defer s.session.Rollback()

    err = s.Insert(User{mobile: "1", name: "1", age: 1, sex: 1})
    if err != nil {
        fmt.Println(err)
        return
    }

    err = s.AddInTx(User{mobile: "1", name: "1", age: 1, sex: 1}, User{mobile: "1", name: "1", age: 1, sex: 1})
    if err != nil {
        fmt.Println(err)
        return
    }

    err = s.session.Commit()
    if err != nil {
        fmt.Println(err)
        return
    }
}

// TestDoNestingTx 测试嵌套事务
func TestDoNestingTx(t *testing.T) {
    userService := NewUserService()
    userService.DoNestingTx()
}

GitHub:https://github.com/alberliu/session

更多原创文章干货分享,请关注公众号
  • 加微信实战群请加微信(注明:实战群):gocnio
暂无回复。
需要 登录 后方可回复, 如果你还没有账号请点击这里 注册