MySQL建表语句生成Golang代码

1. 背景

对于后台开发新的需求时,一般会先进行各种表的设计,写各个表的建表语句

然后根据建立的表,写对应的model代码、基础的增删改查代码(基础的增删改查服务可以划入DAO(Data Access Object)层)。

model代码都有一些固定的格式,可以通过解析SQL建表语句,来自动生成model代码,

对于不同的表,基础的增删改查代码大概率只是换了个表名或者数据库,因此也可以自动生成。

通过自动生成代码,减少重复工作,提示开发效率。

 

2. 整体介绍

目录结构如下,具体代码建Github sql2code

. ├── README.md ├── code2file │   └── code2file.go ├── go.mod ├── go.sum ├── main.go ├── sql2code_tpl │   ├── sql2code_tpl.go │   ├── sql2dao_tpl.txt │   └── sql2model_tpl.txt ├── sql2dao │   ├── sql2dao.go │   └── sql2dao_test.go ├── sql2model │   ├── sql2model.go │   ├── sql2model_test.go │   └── tidb_types.go ├── test │   └── t_student.sql └── util     └── util_strings         └── util_strings.go  7 directories, 15 files

sql2code_tpl: 主要是model和dao的模版代码

sql2model:MySQL建表语句到Model代码的主要处理流程

sql2dao:MySQL建表语句到Dao代码到主要处理流程

有误go get TiDB的types文件夹会出现各种冲突的错误,因此只拷贝的需要的部分代码到sql2model/tidb_types.go

 

3. sql到model

3.1 基本思路

使用SQL解析器获得表名及每一列信息,使用模版生成model代码。

SQL解析器使用的是TiDB parser

3.2 model模版

package {{.PackageName}}  // {{.Comment}} type {{.ModelName}} struct { {{- range .Rows}} 	{{.Name}} {{.GoType}} {{.Tags}} // {{.Comment}} {{- end}} }  func ({{.ModelName}}) TableName() string { 	return "{{.OriginTblName}}" }

3.3 部分实现代码

使用SQL解析器解析建表语句,获得表名,及每一列的列名,注释等信息,主要代码如下

func SQLParse(sql, tablePrefix string) (*ModelTable, error) { 	cts, err := parseCreateTableStmt(sql) 	if err != nil { 		log.Printf("parseCreateTableStmt fail,err:%v", err) 		return nil, err 	} 	mt := &ModelTable{} 	primaryKey := ""  	// table name 	tblName := TableNamePrefixCut(cts.Table.Name.L, tablePrefix) 	mt.ModelName = TableName2ModelName(tblName) 	mt.OriginTblName = cts.Table.Name.L  	// primary 	for _, ctt := range cts.Constraints { 		// only contain one primary key 		if ctt.Tp == ast.ConstraintPrimaryKey { 			if len(ctt.Keys) >= 0 { 				primaryKey = ctt.Keys[0].Column.Name.L 			} 			break 		} 	}  	// comment 	for _, op := range cts.Options { 		if op.Tp == ast.TableOptionComment { 			mt.Comment = op.StrValue 		} 	}  	modelRows := make([]ModelRow, 0, len(cts.Cols)) 	for _, col := range cts.Cols { 		nameLow := col.Name.Name.L 		modelRow := ModelRow{ 			Name: util_strings.ToCamel(col.Name.Name.L), // 需要去除下划线转驼峰 		} 		//fmt.Printf("col: %+v %+v %v %vn", col.Name, col.Tp, HasUnsignedFlag(col.Tp.GetFlag()), col.Tp.GetType()) 		modelRow.GoType = sqlType2GoType(col.Tp.GetType(), col.Tp.GetFlag()) 		for _, op := range col.Options { 			if op.Tp == ast.ColumnOptionComment { 				exprVal, ok := op.Expr.(*test_driver.ValueExpr) 				if !ok { 					fmt.Println("op.Expr.(*test_driver.ValueExpr) fail.") 					continue 				} 				modelRow.Comment = exprVal.Datum.GetString() 				break 			} 		} 		if primaryKey == col.Name.Name.L { 			modelRow.Tags = fmt.Sprintf("`gorm:"column:%v; primary_key" json:"%v"`", nameLow, nameLow) 		} else { 			modelRow.Tags = fmt.Sprintf("`gorm:"column:%v;" json:"%v"`", nameLow, nameLow) 		} 		//fmt.Println(modelRow.Tags) 		modelRows = append(modelRows, modelRow) 	} 	mt.Rows = modelRows 	return mt, nil }

3. sql到dao

2.1 基本思路

  • 获得MySQL建表语句的表名信息
  • 使用模版生成CRUD代码

2.2 模版代码

查看代码
 package {{.PackageName}}  {{template "addTemplate" .}} {{template "deleteTemplate" .}} {{template "updateTemplate" .}} {{template "getMultiTemplate" .}} {{template "getCountTemplate" .}} {{template "getOneTemplate" .}}   {{define "addTemplate"}} func Add{{.ModelName}}(ctx context.Context,obj *{{.ModelPackage}}.{{.ModelName}}, whereMap map[string]interface{}) (error, int64) { 	if whereMap != nil { 		err, existObj := GetOne{{.ModelName}}(ctx, whereMap) 		if err != nil { 			log.Printf("[Add{{.ModelName}}]GetOne{{.ModelName}} fail, err:%v, obj:%v", err, obj) 			return err, int64(0) 		} 		if existObj != nil && existObj.AddTime > int64(0) { 			logs.CtxInfo(ctx, "[Add{{.ModelName}}] {{.ModelName}} exist, existsObj:%v", existObj) 			return nil, existObj.ID 		} 	}  	if obj.AddTime <= 0 { 		obj.AddTime = util_datetime.CurrentMS() 	} 	if obj.UpdateTime <= 0 { 		obj.UpdateTime = util_datetime.CurrentMS() 	}  	res := {{.DBConect}}.Create(obj) 	if res.Error != nil { 		log.Printf("[Add{{.ModelName}}]Add{{.ModelName}} fail, err:%v, obj:%v", res.Error, obj) 		return res.Error, int64(0) 	} 	return res.Error, obj.ID }  {{end}}  {{define "deleteTemplate"}} func Delete{{.ModelName}}(ctx context.Context,whereMap map[string]interface{}) (error, int64) { 	query := db.WhereQuery({{.DBConect}}, whereMap) 	res := query.Delete(&{{.ModelPackage}}.{{.ModelName}}{}) 	if res.Error != nil { 		log.Printf("Delete{{.ModelName}} failed, err:%v, whereMap:%v", res.Error, whereMap) 		return res.Error, int64(0) 	} 	rowsAffected := res.RowsAffected 	return nil, rowsAffected }  {{end}}  {{define "updateTemplate"}} func Update{{.ModelName}}(ctx context.Context, whereMap map[string]interface{}, setMap map[string]interface{}) (error, int64) { 	obj := &{{.ModelPackage}}.{{.ModelName}}{} 	query := {{.DBConect}}.Model(obj) 	query = db.WhereQuery(query, whereMap) 	if updateTime, ok := setMap["update_time"]; !ok || updateTime.(int64) <= 0 { 		setMap["update_time"] = util_datetime.CurrentMS() 	} 	res := query.Updates(setMap) 	if res.Error != nil { 		log.Printf("[Update{{.ModelName}}]Update{{.ModelName}} fail, err:%v, whereMap:%v, setMap:%v", res.Error, whereMap, setMap) 		return res.Error, int64(0) 	} 	rowsAffected := res.RowsAffected 	return nil, rowsAffected }  {{end}}  {{define "getMultiTemplate"}} func GetMulti{{.ModelName}}s(ctx context.Context,whereMap map[string]interface{}, offset, limit int64, orderBy, groupby, fields string) (error, []*{{.ModelPackage}}.{{.ModelName}}) { 	objs := []*{{.ModelPackage}}.{{.ModelName}}{} 	query := db.WhereQuery({{.DBConect}}, whereMap) 	query = db.OrderByQuery(query, orderBy) 	query = db.FieldsQuery(query, fields) 	query = db.GroupByQuery(query, groupby) 	query = db.LimitQuery(query, offset, limit)  	res := query.Find(&objs) 	if res.Error != nil { 		if res.Error.Error() == "record not found" { 			return nil, nil 		} 		log.Printf("[GetMulti{{.ModelName}}s]GetMulti{{.ModelName}}s fail, err:%v", res.Error) 	} 	return res.Error, objs }  {{end}}  {{define "getCountTemplate"}} func GetMulti{{.ModelName}}sCount(ctx context.Context,whereMap map[string]interface{}) (error, int64) { 	cnt := int64(0) 	query := db.WhereQuery({{.DBConect}}, whereMap)  	res := query.Model(&{{.ModelPackage}}.{{.ModelName}}{}).Count(&cnt) 	if res.Error != nil { 		if res.Error.Error() == "record not found" { 			return nil, cnt 		} 		log.Printf("GetMulti{{.ModelName}}sCount fail, err:%v", res.Error) 	} 	return res.Error, cnt }  {{end}}  {{define "getOneTemplate"}} func GetOne{{.ModelName}}(ctx context.Context,whereMap map[string]interface{},fields string)(error, *{{.ModelPackage}}.{{.ModelName}}) { 	err, objs := GetMulti{{.ModelName}}s(ctx, whereMap, 0, 1, "", "", fields) 	if err != nil { 		return err, nil 	} 	if len(objs) >= 1 { 		return nil, objs[0] 	} 	return nil, nil }  {{end}}

2.2 部分实现代码

主要是根据表名获取对应的model名称、包名等,再利用模版生成代码。

func SQL2Dao(sql string, tablePrefix, packagePrefix, dbCon string) (string, error) { 	tblName, err := sql2model.TableNameGetFromSQL(sql, tablePrefix) 	if err != nil { 		return "", err 	}  	modelPackage := sql2model.ModelPackageGet(tblName, tablePrefix, packagePrefix) 	packageName := DaoPackageNameGet(tblName, tablePrefix, packagePrefix)  	df := DaoFile{ 		PackageName:  packageName, 		ModelPackage: modelPackage, 		ModelName:    sql2model.TableName2ModelName(tblName), 		DBConect:     dbCon, 	} 	return daoFileGen(df) }

4. 使用方式

Usage of this program:   -dbcon string         db connect name   -if string         File path of the SQL statement that creates the table   -op int         1:gen model code 2:gen dao code 3:both (default 1)   -pp string         package prefix add for go file   -sql string         SQL statement to create table   -tp string         table prefix of table name to cut

5. 使用实例

$ go run main.go -if=./test/t_student.sql -dbcon=UserDB -tp="t_" -pp=user -op=3 model code have been write to  ./output/user_student.go model code have been write to  ./output/user_student_service.go

6. 完整代码

Github sql2code

5. 参考

1. TiDB parser quickstart

 

发表评论

评论已关闭。

相关文章

当前内容话题