Skip to content

Commit 952e10a

Browse files
authored
feat(model): generate ormbfile.yaml automatically (#170)
* feat(model): auto generate ormbfile.yaml * fix(model): fix comments
1 parent cd8d716 commit 952e10a

File tree

3 files changed

+162
-3
lines changed

3 files changed

+162
-3
lines changed

pkg/saver/saver.go

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import (
1414
"github.com/kleveross/ormb/pkg/consts"
1515
"github.com/kleveross/ormb/pkg/model"
1616
"github.com/kleveross/ormb/pkg/parser"
17+
"github.com/kleveross/ormb/pkg/util"
1718
)
1819

1920
// Saver is the implementation.
@@ -30,17 +31,36 @@ func New() Interface {
3031

3132
// Save saves the model from the path to the memory.
3233
func (d Saver) Save(path string) (*model.Model, error) {
34+
modelPath := filepath.Join(path, consts.ORMBModelDirectory)
35+
ormbfilePath := filepath.Join(path, consts.ORMBfileName)
36+
37+
if _, err := os.Stat(ormbfilePath); err != nil {
38+
if os.IsNotExist(err) {
39+
format, err := util.InferModelFormat(modelPath)
40+
if err != nil {
41+
return nil, err
42+
}
43+
44+
if format != "" {
45+
err := util.WriteORMBFile(ormbfilePath, format)
46+
if err != nil {
47+
return nil, err
48+
}
49+
}
50+
} else {
51+
return nil, err
52+
}
53+
}
54+
3355
// Save model config from <path>/ormbfile.yaml.
34-
dat, err := ioutil.ReadFile(filepath.Join(path, consts.ORMBfileName))
56+
dat, err := ioutil.ReadFile(ormbfilePath)
3557
if err != nil {
3658
return nil, err
3759
}
38-
3960
metadata := &model.Metadata{}
4061
if metadata, err = d.Parser.Parse(dat); err != nil {
4162
return nil, err
4263
}
43-
4464
format := model.Format(metadata.Format)
4565
if err := format.ValidateDirectory(path); err != nil {
4666
return nil, err

pkg/util/util.go

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
package util
2+
3+
import (
4+
"io/ioutil"
5+
"os"
6+
"path"
7+
"strings"
8+
9+
"gopkg.in/yaml.v2"
10+
11+
ormbmodel "github.com/kleveross/ormb/pkg/model"
12+
)
13+
14+
// InferModelFormat infers model format by files' ext.
15+
func InferModelFormat(dir string) (ormbmodel.Format, error) {
16+
fileList, err := ioutil.ReadDir(dir)
17+
if err != nil {
18+
return "", err
19+
}
20+
21+
netdefFileNum := 0
22+
mxnetFileNum := 0
23+
24+
for _, file := range fileList {
25+
if file.IsDir() {
26+
continue
27+
}
28+
29+
fileExtName := strings.Trim(path.Ext(file.Name()), ".")
30+
switch fileExtName {
31+
case "pb":
32+
if strings.HasSuffix(file.Name(), "saved_model.pb") {
33+
return ormbmodel.FormatSavedModel, nil
34+
} else if strings.HasSuffix(file.Name(), "init_net.pb") || strings.HasSuffix(file.Name(), "predict_net.pb") {
35+
netdefFileNum++
36+
}
37+
case "onnx":
38+
return ormbmodel.FormatONNX, nil
39+
case "graphdef":
40+
return ormbmodel.FormatGraphDef, nil
41+
case "caffemodel":
42+
return ormbmodel.FormatCaffeModel, nil
43+
case "pt":
44+
return ormbmodel.FormatTorchScript, nil
45+
case "plan", "engine":
46+
return ormbmodel.FormatTensorRT, nil
47+
case "pmml":
48+
return ormbmodel.FormatPMML, nil
49+
case "params":
50+
mxnetFileNum++
51+
case "json":
52+
if strings.HasSuffix(file.Name(), "symbol.json") {
53+
mxnetFileNum++
54+
}
55+
case "h5":
56+
return ormbmodel.FormatH5, nil
57+
case "xgboost":
58+
return ormbmodel.FormatXGBoost, nil
59+
case "joblib":
60+
return ormbmodel.FormatSKLearn, nil
61+
}
62+
}
63+
64+
if netdefFileNum == 2 {
65+
return ormbmodel.FormatNetDef, nil
66+
}
67+
68+
if mxnetFileNum == 2 {
69+
return ormbmodel.FormatMXNetParams, nil
70+
}
71+
72+
return ormbmodel.FormatOthers, nil
73+
}
74+
75+
// WriteORMBFile write ormbfile.yaml if file is not exist.
76+
func WriteORMBFile(filePath string, format ormbmodel.Format) error {
77+
metadata := &ormbmodel.Metadata{
78+
Format: string(format),
79+
}
80+
data, err := yaml.Marshal(metadata)
81+
if err != nil {
82+
return err
83+
}
84+
85+
f, err := os.OpenFile(filePath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0666)
86+
if err != nil {
87+
return err
88+
}
89+
defer f.Close() // nolint
90+
91+
_, err = f.Write(data)
92+
if err != nil {
93+
return err
94+
}
95+
96+
return nil
97+
}

pkg/util/util_test.go

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
package util
2+
3+
import (
4+
"os"
5+
"path"
6+
"testing"
7+
8+
ormbmodel "github.com/kleveross/ormb/pkg/model"
9+
)
10+
11+
func TestWriteORMBFile(t *testing.T) {
12+
cwd, _ := os.Getwd()
13+
14+
ormbfilePath := path.Join(cwd, "ormbfile.yaml")
15+
defer os.RemoveAll(ormbfilePath)
16+
17+
type args struct {
18+
filePath string
19+
format ormbmodel.Format
20+
}
21+
tests := []struct {
22+
name string
23+
args args
24+
wantErr bool
25+
}{
26+
{
27+
name: "WriteORMBFile",
28+
args: args{
29+
filePath: ormbfilePath,
30+
format: ormbmodel.FormatMXNetParams,
31+
},
32+
wantErr: false,
33+
},
34+
}
35+
for _, tt := range tests {
36+
t.Run(tt.name, func(t *testing.T) {
37+
if err := WriteORMBFile(tt.args.filePath, tt.args.format); (err != nil) != tt.wantErr {
38+
t.Errorf("WriteORMBFile() error = %v, wantErr %v", err, tt.wantErr)
39+
}
40+
})
41+
}
42+
}

0 commit comments

Comments
 (0)