Skip to content

Commit 3774807

Browse files
authored
---
yaml --- r: 35229 b: refs/heads/autosynth-websecurityscanner c: 6636539 h: refs/heads/master i: 35227: 5d1418e
1 parent 418fba9 commit 3774807

6 files changed

Lines changed: 123 additions & 2 deletions

File tree

[refs]

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ refs/heads/autosynth-speech: c563dcd420cce0a37c39b1b9c24be1b9ba604dc7
142142
refs/heads/autosynth-tasks: 25d1eafe8cb66b00e3dad765dac74a5b45b83e63
143143
refs/heads/autosynth-texttospeech: 7a3ad430dddaed7a76f2026064502680c9339915
144144
refs/heads/autosynth-trace: 31564421a4b29f8257a6daea7f9a19838ac6459f
145-
refs/heads/autosynth-websecurityscanner: 88f539a363feff0b44b7b5318e8cd1c425702c89
145+
refs/heads/autosynth-websecurityscanner: 6636539f0d61385b561e89b2a7dc25bd989747bd
146146
refs/heads/bigquerystorage: 06db74d123d7f8a3ef48755c2fcabed09faf8e64
147147
refs/heads/elharo-patch-1: ce159ef828d3c545991ff78e7b6e0d912a9453e9
148148
refs/heads/snyk-fix-r0punm: 1f0e6519ffd9f6cc09bcce1ccdf3fb61b6f4f9b5

branches/autosynth-websecurityscanner/google-cloud-clients/google-cloud-bigquery/src/main/java/com/google/cloud/bigquery/BigQuery.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,9 @@ enum ModelField implements FieldSelector {
132132
LAST_MODIFIED_TIME("lastModifiedTime"),
133133
LOCATION("location"),
134134
MODEL_REFERENCE("modelReference"),
135-
TIME_PARTITIONING("timePartitioning"),
135+
TRAINING_RUNS("trainingRuns"),
136+
LABEL_COLUMNS("labelColumns"),
137+
FEATURE_COLUMNS("featureColumns"),
136138
TYPE("modelType");
137139

138140
static final List<? extends FieldSelector> REQUIRED_FIELDS = ImmutableList.of(MODEL_REFERENCE);

branches/autosynth-websecurityscanner/google-cloud-clients/google-cloud-bigquery/src/main/java/com/google/cloud/bigquery/Model.java

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,12 @@
1818

1919
import static com.google.common.base.Preconditions.checkNotNull;
2020

21+
import com.google.api.services.bigquery.model.StandardSqlField;
22+
import com.google.api.services.bigquery.model.TrainingRun;
2123
import com.google.cloud.bigquery.BigQuery.ModelOption;
2224
import java.io.IOException;
2325
import java.io.ObjectInputStream;
26+
import java.util.List;
2427
import java.util.Map;
2528
import java.util.Objects;
2629

@@ -106,6 +109,24 @@ public Builder setLabels(Map<String, String> labels) {
106109
return this;
107110
}
108111

112+
@Override
113+
Builder setTrainingRuns(List<TrainingRun> trainingRunList) {
114+
infoBuilder.setTrainingRuns(trainingRunList);
115+
return this;
116+
}
117+
118+
@Override
119+
Builder setLabelColumns(List<StandardSqlField> labelColumnList) {
120+
infoBuilder.setLabelColumns(labelColumnList);
121+
return this;
122+
}
123+
124+
@Override
125+
Builder setFeatureColumns(List<StandardSqlField> featureColumnList) {
126+
infoBuilder.setFeatureColumns(featureColumnList);
127+
return this;
128+
}
129+
109130
public Model build() {
110131
return new Model(bigquery, infoBuilder);
111132
}

branches/autosynth-websecurityscanner/google-cloud-clients/google-cloud-bigquery/src/main/java/com/google/cloud/bigquery/ModelInfo.java

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,17 @@
1818

1919
import static com.google.common.base.Preconditions.checkNotNull;
2020

21+
import com.google.api.core.BetaApi;
2122
import com.google.api.services.bigquery.model.Model;
23+
import com.google.api.services.bigquery.model.StandardSqlField;
24+
import com.google.api.services.bigquery.model.TrainingRun;
2225
import com.google.common.base.Function;
2326
import com.google.common.base.MoreObjects;
2427
import com.google.common.base.Strings;
28+
import com.google.common.collect.ImmutableList;
2529
import java.io.Serializable;
30+
import java.util.Collections;
31+
import java.util.List;
2632
import java.util.Map;
2733
import java.util.Objects;
2834

@@ -62,6 +68,9 @@ public Model apply(ModelInfo ModelInfo) {
6268
private final Long lastModifiedTime;
6369
private final Long expirationTime;
6470
private final Labels labels;
71+
private final ImmutableList<TrainingRun> trainingRunList;
72+
private final ImmutableList<StandardSqlField> featureColumnList;
73+
private final ImmutableList<StandardSqlField> labelColumnList;
6574

6675
/** A builder for {@code ModelInfo} objects. */
6776
public abstract static class Builder {
@@ -97,6 +106,12 @@ public abstract static class Builder {
97106

98107
abstract Builder setLastModifiedTime(Long lastModifiedTime);
99108

109+
abstract Builder setTrainingRuns(List<TrainingRun> trainingRunList);
110+
111+
abstract Builder setLabelColumns(List<StandardSqlField> labelColumnList);
112+
113+
abstract Builder setFeatureColumns(List<StandardSqlField> featureColumnList);
114+
100115
/** Creates a {@code ModelInfo} object. */
101116
public abstract ModelInfo build();
102117
}
@@ -112,6 +127,9 @@ static class BuilderImpl extends Builder {
112127
private Long lastModifiedTime;
113128
private Long expirationTime;
114129
private Labels labels = Labels.ZERO;
130+
private List<TrainingRun> trainingRunList = Collections.emptyList();
131+
private List<StandardSqlField> labelColumnList = Collections.emptyList();
132+
private List<StandardSqlField> featureColumnList = Collections.emptyList();
115133

116134
BuilderImpl() {}
117135

@@ -124,6 +142,9 @@ static class BuilderImpl extends Builder {
124142
this.creationTime = modelInfo.creationTime;
125143
this.lastModifiedTime = modelInfo.lastModifiedTime;
126144
this.expirationTime = modelInfo.expirationTime;
145+
this.trainingRunList = modelInfo.trainingRunList;
146+
this.labelColumnList = modelInfo.labelColumnList;
147+
this.featureColumnList = modelInfo.featureColumnList;
127148
}
128149

129150
BuilderImpl(Model modelPb) {
@@ -139,6 +160,15 @@ static class BuilderImpl extends Builder {
139160
this.lastModifiedTime = modelPb.getLastModifiedTime();
140161
this.expirationTime = modelPb.getExpirationTime();
141162
this.labels = Labels.fromPb(modelPb.getLabels());
163+
if (modelPb.getTrainingRuns() != null) {
164+
this.trainingRunList = modelPb.getTrainingRuns();
165+
}
166+
if (modelPb.getLabelColumns() != null) {
167+
this.labelColumnList = modelPb.getLabelColumns();
168+
}
169+
if (modelPb.getFeatureColumns() != null) {
170+
this.featureColumnList = modelPb.getFeatureColumns();
171+
}
142172
}
143173

144174
@Override
@@ -195,6 +225,24 @@ public Builder setLabels(Map<String, String> labels) {
195225
return this;
196226
}
197227

228+
@Override
229+
Builder setTrainingRuns(List<TrainingRun> trainingRunList) {
230+
this.trainingRunList = checkNotNull(trainingRunList);
231+
return this;
232+
}
233+
234+
@Override
235+
Builder setLabelColumns(List<StandardSqlField> labelColumnList) {
236+
this.labelColumnList = checkNotNull(labelColumnList);
237+
return this;
238+
}
239+
240+
@Override
241+
Builder setFeatureColumns(List<StandardSqlField> featureColumnList) {
242+
this.featureColumnList = checkNotNull(featureColumnList);
243+
return this;
244+
}
245+
198246
@Override
199247
public ModelInfo build() {
200248
return new ModelInfo(this);
@@ -211,6 +259,9 @@ public ModelInfo build() {
211259
this.lastModifiedTime = builder.lastModifiedTime;
212260
this.expirationTime = builder.expirationTime;
213261
this.labels = builder.labels;
262+
this.trainingRunList = ImmutableList.copyOf(builder.trainingRunList);
263+
this.labelColumnList = ImmutableList.copyOf(builder.labelColumnList);
264+
this.featureColumnList = ImmutableList.copyOf(builder.featureColumnList);
214265
}
215266

216267
/** Returns the hash of the model resource. */
@@ -261,6 +312,24 @@ public Map<String, String> getLabels() {
261312
return labels.userMap();
262313
}
263314

315+
/** Returns metadata about each training run iteration. */
316+
@BetaApi
317+
public ImmutableList<TrainingRun> getTrainingRuns() {
318+
return trainingRunList;
319+
}
320+
321+
/** Returns information about the label columns for this model. */
322+
@BetaApi
323+
public ImmutableList<StandardSqlField> getLabelColumns() {
324+
return labelColumnList;
325+
}
326+
327+
/** Returns information about the feature columns for this model. */
328+
@BetaApi
329+
public ImmutableList<StandardSqlField> getFeatureColumns() {
330+
return featureColumnList;
331+
}
332+
264333
public Builder toBuilder() {
265334
return new BuilderImpl(this);
266335
}
@@ -277,6 +346,9 @@ public String toString() {
277346
.add("lastModifiedTime", lastModifiedTime)
278347
.add("expirationTime", expirationTime)
279348
.add("labels", labels)
349+
.add("trainingRuns", trainingRunList)
350+
.add("labelColumns", labelColumnList)
351+
.add("featureColumns", featureColumnList)
280352
.toString();
281353
}
282354

@@ -321,6 +393,9 @@ Model toPb() {
321393
modelPb.setLastModifiedTime(lastModifiedTime);
322394
modelPb.setExpirationTime(expirationTime);
323395
modelPb.setLabels(labels.toPb());
396+
modelPb.setTrainingRuns(trainingRunList);
397+
modelPb.setLabelColumns(labelColumnList);
398+
modelPb.setFeatureColumns(featureColumnList);
324399
return modelPb;
325400
}
326401

branches/autosynth-websecurityscanner/google-cloud-clients/google-cloud-bigquery/src/test/java/com/google/cloud/bigquery/ModelInfoTest.java

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,10 @@
1818
import static org.junit.Assert.assertEquals;
1919
import static org.junit.Assert.assertNull;
2020

21+
import com.google.api.services.bigquery.model.TrainingOptions;
22+
import com.google.api.services.bigquery.model.TrainingRun;
23+
import java.util.Arrays;
24+
import java.util.List;
2125
import org.junit.Test;
2226

2327
public class ModelInfoTest {
@@ -30,6 +34,12 @@ public class ModelInfoTest {
3034
private static final String DESCRIPTION = "description";
3135
private static final String FRIENDLY_NAME = "friendlyname";
3236

37+
private static final TrainingOptions TRAINING_OPTIONS =
38+
new TrainingOptions().setDataSplitColumn("foo").setEarlyStop(true).setLossType("bar");
39+
private static final TrainingRun TRAINING_RUN =
40+
new TrainingRun().setTrainingOptions(TRAINING_OPTIONS);
41+
private static final List<TrainingRun> TRAINING_RUN_LIST = Arrays.asList(TRAINING_RUN);
42+
3343
private static final ModelInfo MODEL_INFO =
3444
ModelInfo.newBuilder(MODEL_ID)
3545
.setEtag(ETAG)
@@ -38,6 +48,7 @@ public class ModelInfoTest {
3848
.setLastModifiedTime(LAST_MODIFIED_TIME)
3949
.setDescription(DESCRIPTION)
4050
.setFriendlyName(FRIENDLY_NAME)
51+
.setTrainingRuns(TRAINING_RUN_LIST)
4152
.build();
4253

4354
@Test
@@ -59,6 +70,7 @@ public void testBuilder() {
5970
assertEquals(EXPIRATION_TIME, MODEL_INFO.getExpirationTime());
6071
assertEquals(DESCRIPTION, MODEL_INFO.getDescription());
6172
assertEquals(FRIENDLY_NAME, MODEL_INFO.getFriendlyName());
73+
assertEquals(TRAINING_OPTIONS, MODEL_INFO.getTrainingRuns().get(0).getTrainingOptions());
6274
}
6375

6476
@Test
@@ -71,6 +83,9 @@ public void testOf() {
7183
assertNull(modelInfo.getExpirationTime());
7284
assertNull(modelInfo.getDescription());
7385
assertNull(modelInfo.getFriendlyName());
86+
assertEquals(modelInfo.getTrainingRuns().isEmpty(), true);
87+
assertEquals(modelInfo.getLabelColumns().isEmpty(), true);
88+
assertEquals(modelInfo.getFeatureColumns().isEmpty(), true);
7489
}
7590

7691
@Test
@@ -94,5 +109,8 @@ private void compareModelInfo(ModelInfo expected, ModelInfo value) {
94109
assertEquals(expected.getFriendlyName(), value.getFriendlyName());
95110
assertEquals(expected.getLabels(), value.getLabels());
96111
assertEquals(expected.hashCode(), value.hashCode());
112+
assertEquals(expected.getTrainingRuns(), value.getTrainingRuns());
113+
assertEquals(expected.getLabelColumns(), value.getLabelColumns());
114+
assertEquals(expected.getFeatureColumns(), value.getFeatureColumns());
97115
}
98116
}

branches/autosynth-websecurityscanner/google-cloud-clients/google-cloud-bigquery/src/test/java/com/google/cloud/bigquery/it/ITBigQueryTest.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1083,6 +1083,11 @@ public void testModelLifecycle() throws InterruptedException {
10831083
Model model = bigquery.getModel(modelId);
10841084
assertNotNull(model);
10851085
assertEquals(model.getModelType(), "LINEAR_REGRESSION");
1086+
// Compare the extended model metadata.
1087+
assertEquals(model.getFeatureColumns().get(0).getName(), "f1");
1088+
assertEquals(model.getLabelColumns().get(0).getName(), "predicted_label");
1089+
assertEquals(
1090+
model.getTrainingRuns().get(0).getTrainingOptions().getLearnRateStrategy(), "CONSTANT");
10861091

10871092
// Mutate metadata.
10881093
ModelInfo info = model.toBuilder().setDescription("TEST").build();

0 commit comments

Comments
 (0)