Skip to content

Commit 6636539

Browse files
authored
BigQuery: Augment BQ Model metadata. (#5214)
* BigQuery: Augment BQ ML model metadata. This change introduces extended ML model metadata into the existing CRUD implementation for BigQuery ML models. This implementation exposes the underlying discovery types directly without an abstraction for the ML metadata. Methods that provide this metadata are all marked @BetaApi to indicate the getters may change.
1 parent 88f539a commit 6636539

5 files changed

Lines changed: 122 additions & 1 deletion

File tree

google-cloud-clients/google-cloud-bigquery/src/main/java/com/google/cloud/bigquery/BigQuery.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,9 @@ enum ModelField implements FieldSelector {
132132
LAST_MODIFIED_TIME("lastModifiedTime"),
133133
LOCATION("location"),
134134
MODEL_REFERENCE("modelReference"),
135-
TIME_PARTITIONING("timePartitioning"),
135+
TRAINING_RUNS("trainingRuns"),
136+
LABEL_COLUMNS("labelColumns"),
137+
FEATURE_COLUMNS("featureColumns"),
136138
TYPE("modelType");
137139

138140
static final List<? extends FieldSelector> REQUIRED_FIELDS = ImmutableList.of(MODEL_REFERENCE);

google-cloud-clients/google-cloud-bigquery/src/main/java/com/google/cloud/bigquery/Model.java

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,12 @@
1818

1919
import static com.google.common.base.Preconditions.checkNotNull;
2020

21+
import com.google.api.services.bigquery.model.StandardSqlField;
22+
import com.google.api.services.bigquery.model.TrainingRun;
2123
import com.google.cloud.bigquery.BigQuery.ModelOption;
2224
import java.io.IOException;
2325
import java.io.ObjectInputStream;
26+
import java.util.List;
2427
import java.util.Map;
2528
import java.util.Objects;
2629

@@ -106,6 +109,24 @@ public Builder setLabels(Map<String, String> labels) {
106109
return this;
107110
}
108111

112+
@Override
113+
Builder setTrainingRuns(List<TrainingRun> trainingRunList) {
114+
infoBuilder.setTrainingRuns(trainingRunList);
115+
return this;
116+
}
117+
118+
@Override
119+
Builder setLabelColumns(List<StandardSqlField> labelColumnList) {
120+
infoBuilder.setLabelColumns(labelColumnList);
121+
return this;
122+
}
123+
124+
@Override
125+
Builder setFeatureColumns(List<StandardSqlField> featureColumnList) {
126+
infoBuilder.setFeatureColumns(featureColumnList);
127+
return this;
128+
}
129+
109130
public Model build() {
110131
return new Model(bigquery, infoBuilder);
111132
}

google-cloud-clients/google-cloud-bigquery/src/main/java/com/google/cloud/bigquery/ModelInfo.java

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,17 @@
1818

1919
import static com.google.common.base.Preconditions.checkNotNull;
2020

21+
import com.google.api.core.BetaApi;
2122
import com.google.api.services.bigquery.model.Model;
23+
import com.google.api.services.bigquery.model.StandardSqlField;
24+
import com.google.api.services.bigquery.model.TrainingRun;
2225
import com.google.common.base.Function;
2326
import com.google.common.base.MoreObjects;
2427
import com.google.common.base.Strings;
28+
import com.google.common.collect.ImmutableList;
2529
import java.io.Serializable;
30+
import java.util.Collections;
31+
import java.util.List;
2632
import java.util.Map;
2733
import java.util.Objects;
2834

@@ -62,6 +68,9 @@ public Model apply(ModelInfo ModelInfo) {
6268
private final Long lastModifiedTime;
6369
private final Long expirationTime;
6470
private final Labels labels;
71+
private final ImmutableList<TrainingRun> trainingRunList;
72+
private final ImmutableList<StandardSqlField> featureColumnList;
73+
private final ImmutableList<StandardSqlField> labelColumnList;
6574

6675
/** A builder for {@code ModelInfo} objects. */
6776
public abstract static class Builder {
@@ -97,6 +106,12 @@ public abstract static class Builder {
97106

98107
abstract Builder setLastModifiedTime(Long lastModifiedTime);
99108

109+
abstract Builder setTrainingRuns(List<TrainingRun> trainingRunList);
110+
111+
abstract Builder setLabelColumns(List<StandardSqlField> labelColumnList);
112+
113+
abstract Builder setFeatureColumns(List<StandardSqlField> featureColumnList);
114+
100115
/** Creates a {@code ModelInfo} object. */
101116
public abstract ModelInfo build();
102117
}
@@ -112,6 +127,9 @@ static class BuilderImpl extends Builder {
112127
private Long lastModifiedTime;
113128
private Long expirationTime;
114129
private Labels labels = Labels.ZERO;
130+
private List<TrainingRun> trainingRunList = Collections.emptyList();
131+
private List<StandardSqlField> labelColumnList = Collections.emptyList();
132+
private List<StandardSqlField> featureColumnList = Collections.emptyList();
115133

116134
BuilderImpl() {}
117135

@@ -124,6 +142,9 @@ static class BuilderImpl extends Builder {
124142
this.creationTime = modelInfo.creationTime;
125143
this.lastModifiedTime = modelInfo.lastModifiedTime;
126144
this.expirationTime = modelInfo.expirationTime;
145+
this.trainingRunList = modelInfo.trainingRunList;
146+
this.labelColumnList = modelInfo.labelColumnList;
147+
this.featureColumnList = modelInfo.featureColumnList;
127148
}
128149

129150
BuilderImpl(Model modelPb) {
@@ -139,6 +160,15 @@ static class BuilderImpl extends Builder {
139160
this.lastModifiedTime = modelPb.getLastModifiedTime();
140161
this.expirationTime = modelPb.getExpirationTime();
141162
this.labels = Labels.fromPb(modelPb.getLabels());
163+
if (modelPb.getTrainingRuns() != null) {
164+
this.trainingRunList = modelPb.getTrainingRuns();
165+
}
166+
if (modelPb.getLabelColumns() != null) {
167+
this.labelColumnList = modelPb.getLabelColumns();
168+
}
169+
if (modelPb.getFeatureColumns() != null) {
170+
this.featureColumnList = modelPb.getFeatureColumns();
171+
}
142172
}
143173

144174
@Override
@@ -195,6 +225,24 @@ public Builder setLabels(Map<String, String> labels) {
195225
return this;
196226
}
197227

228+
@Override
229+
Builder setTrainingRuns(List<TrainingRun> trainingRunList) {
230+
this.trainingRunList = checkNotNull(trainingRunList);
231+
return this;
232+
}
233+
234+
@Override
235+
Builder setLabelColumns(List<StandardSqlField> labelColumnList) {
236+
this.labelColumnList = checkNotNull(labelColumnList);
237+
return this;
238+
}
239+
240+
@Override
241+
Builder setFeatureColumns(List<StandardSqlField> featureColumnList) {
242+
this.featureColumnList = checkNotNull(featureColumnList);
243+
return this;
244+
}
245+
198246
@Override
199247
public ModelInfo build() {
200248
return new ModelInfo(this);
@@ -211,6 +259,9 @@ public ModelInfo build() {
211259
this.lastModifiedTime = builder.lastModifiedTime;
212260
this.expirationTime = builder.expirationTime;
213261
this.labels = builder.labels;
262+
this.trainingRunList = ImmutableList.copyOf(builder.trainingRunList);
263+
this.labelColumnList = ImmutableList.copyOf(builder.labelColumnList);
264+
this.featureColumnList = ImmutableList.copyOf(builder.featureColumnList);
214265
}
215266

216267
/** Returns the hash of the model resource. */
@@ -261,6 +312,24 @@ public Map<String, String> getLabels() {
261312
return labels.userMap();
262313
}
263314

315+
/** Returns metadata about each training run iteration. */
316+
@BetaApi
317+
public ImmutableList<TrainingRun> getTrainingRuns() {
318+
return trainingRunList;
319+
}
320+
321+
/** Returns information about the label columns for this model. */
322+
@BetaApi
323+
public ImmutableList<StandardSqlField> getLabelColumns() {
324+
return labelColumnList;
325+
}
326+
327+
/** Returns information about the feature columns for this model. */
328+
@BetaApi
329+
public ImmutableList<StandardSqlField> getFeatureColumns() {
330+
return featureColumnList;
331+
}
332+
264333
public Builder toBuilder() {
265334
return new BuilderImpl(this);
266335
}
@@ -277,6 +346,9 @@ public String toString() {
277346
.add("lastModifiedTime", lastModifiedTime)
278347
.add("expirationTime", expirationTime)
279348
.add("labels", labels)
349+
.add("trainingRuns", trainingRunList)
350+
.add("labelColumns", labelColumnList)
351+
.add("featureColumns", featureColumnList)
280352
.toString();
281353
}
282354

@@ -321,6 +393,9 @@ Model toPb() {
321393
modelPb.setLastModifiedTime(lastModifiedTime);
322394
modelPb.setExpirationTime(expirationTime);
323395
modelPb.setLabels(labels.toPb());
396+
modelPb.setTrainingRuns(trainingRunList);
397+
modelPb.setLabelColumns(labelColumnList);
398+
modelPb.setFeatureColumns(featureColumnList);
324399
return modelPb;
325400
}
326401

google-cloud-clients/google-cloud-bigquery/src/test/java/com/google/cloud/bigquery/ModelInfoTest.java

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,10 @@
1818
import static org.junit.Assert.assertEquals;
1919
import static org.junit.Assert.assertNull;
2020

21+
import com.google.api.services.bigquery.model.TrainingOptions;
22+
import com.google.api.services.bigquery.model.TrainingRun;
23+
import java.util.Arrays;
24+
import java.util.List;
2125
import org.junit.Test;
2226

2327
public class ModelInfoTest {
@@ -30,6 +34,12 @@ public class ModelInfoTest {
3034
private static final String DESCRIPTION = "description";
3135
private static final String FRIENDLY_NAME = "friendlyname";
3236

37+
private static final TrainingOptions TRAINING_OPTIONS =
38+
new TrainingOptions().setDataSplitColumn("foo").setEarlyStop(true).setLossType("bar");
39+
private static final TrainingRun TRAINING_RUN =
40+
new TrainingRun().setTrainingOptions(TRAINING_OPTIONS);
41+
private static final List<TrainingRun> TRAINING_RUN_LIST = Arrays.asList(TRAINING_RUN);
42+
3343
private static final ModelInfo MODEL_INFO =
3444
ModelInfo.newBuilder(MODEL_ID)
3545
.setEtag(ETAG)
@@ -38,6 +48,7 @@ public class ModelInfoTest {
3848
.setLastModifiedTime(LAST_MODIFIED_TIME)
3949
.setDescription(DESCRIPTION)
4050
.setFriendlyName(FRIENDLY_NAME)
51+
.setTrainingRuns(TRAINING_RUN_LIST)
4152
.build();
4253

4354
@Test
@@ -59,6 +70,7 @@ public void testBuilder() {
5970
assertEquals(EXPIRATION_TIME, MODEL_INFO.getExpirationTime());
6071
assertEquals(DESCRIPTION, MODEL_INFO.getDescription());
6172
assertEquals(FRIENDLY_NAME, MODEL_INFO.getFriendlyName());
73+
assertEquals(TRAINING_OPTIONS, MODEL_INFO.getTrainingRuns().get(0).getTrainingOptions());
6274
}
6375

6476
@Test
@@ -71,6 +83,9 @@ public void testOf() {
7183
assertNull(modelInfo.getExpirationTime());
7284
assertNull(modelInfo.getDescription());
7385
assertNull(modelInfo.getFriendlyName());
86+
assertEquals(modelInfo.getTrainingRuns().isEmpty(), true);
87+
assertEquals(modelInfo.getLabelColumns().isEmpty(), true);
88+
assertEquals(modelInfo.getFeatureColumns().isEmpty(), true);
7489
}
7590

7691
@Test
@@ -94,5 +109,8 @@ private void compareModelInfo(ModelInfo expected, ModelInfo value) {
94109
assertEquals(expected.getFriendlyName(), value.getFriendlyName());
95110
assertEquals(expected.getLabels(), value.getLabels());
96111
assertEquals(expected.hashCode(), value.hashCode());
112+
assertEquals(expected.getTrainingRuns(), value.getTrainingRuns());
113+
assertEquals(expected.getLabelColumns(), value.getLabelColumns());
114+
assertEquals(expected.getFeatureColumns(), value.getFeatureColumns());
97115
}
98116
}

google-cloud-clients/google-cloud-bigquery/src/test/java/com/google/cloud/bigquery/it/ITBigQueryTest.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1083,6 +1083,11 @@ public void testModelLifecycle() throws InterruptedException {
10831083
Model model = bigquery.getModel(modelId);
10841084
assertNotNull(model);
10851085
assertEquals(model.getModelType(), "LINEAR_REGRESSION");
1086+
// Compare the extended model metadata.
1087+
assertEquals(model.getFeatureColumns().get(0).getName(), "f1");
1088+
assertEquals(model.getLabelColumns().get(0).getName(), "predicted_label");
1089+
assertEquals(
1090+
model.getTrainingRuns().get(0).getTrainingOptions().getLearnRateStrategy(), "CONSTANT");
10861091

10871092
// Mutate metadata.
10881093
ModelInfo info = model.toBuilder().setDescription("TEST").build();

0 commit comments

Comments
 (0)