Skip to content
This repository was archived by the owner on Feb 25, 2025. It is now read-only.

Commit 5d5a0c0

Browse files
committed
Check in token-level language model via tflite ffi
Local test run after running `gclient sync`: ariaye@ariaye1:~/sdk/sdk$ dart pkg/analysis_server/test/services/completion/dart/language_model_test.dart 00:00 +0: calculates lookback INFO: Initialized TensorFlow Lite runtime. 00:00 +1: predict with defaults 00:01 +2: predict with confidence scores 00:03 +3: predict when no previous tokens 00:04 +4: All tests passed! Change-Id: I4181bea09cf8fec74d03bba4f83cd26dac818f30 Reviewed-on: https://dart-review.googlesource.com/c/sdk/+/109662 Reviewed-by: Brian Wilkerson <[email protected]>
1 parent b9ab8ef commit 5d5a0c0

File tree

6 files changed

+204
-0
lines changed

6 files changed

+204
-0
lines changed

.packages

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@ test_process:third_party/pkg/test_process/lib
100100
test_reflective_loader:third_party/pkg/test_reflective_loader/lib
101101
test_runner:pkg/test_runner/lib
102102
testing:pkg/testing/lib
103+
tflite_native:third_party/pkg/tflite_native/lib
103104
typed_data:third_party/pkg/typed_data/lib
104105
unittest:third_party/pkg/unittest/lib
105106
usage:third_party/pkg/usage/lib

DEPS

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,7 @@ vars = {
135135
"term_glyph_tag": "1.0.1",
136136
"test_reflective_loader_tag": "0.1.8",
137137
"test_tag": "test-v1.6.4",
138+
"tflite_native_rev": "712b8a93fbb4caf83ffed37f154da88c2a517a91",
138139
"typed_data_tag": "1.1.6",
139140
"unittest_rev": "2b8375bc98bb9dc81c539c91aaea6adce12e1072",
140141
"usage_tag": "3.4.0",
@@ -360,6 +361,8 @@ deps = {
360361
Var("dart_git") + "term_glyph.git" + "@" + Var("term_glyph_tag"),
361362
Var("dart_root") + "/third_party/pkg/test":
362363
Var("dart_git") + "test.git" + "@" + Var("test_tag"),
364+
Var("dart_root") + "/third_party/pkg/tflite_native":
365+
Var("dart_git") + "tflite_native.git" + "@" + Var("tflite_native_rev"),
363366
Var("dart_root") + "/third_party/pkg/test_descriptor":
364367
Var("dart_git") + "test_descriptor.git" + "@" + Var("test_descriptor_tag"),
365368
Var("dart_root") + "/third_party/pkg/test_process":
@@ -399,6 +402,16 @@ deps = {
399402
"dep_type": "cipd",
400403
},
401404

405+
Var("dart_root") + "/pkg/analysis_server/language_model": {
406+
"packages": [
407+
{
408+
"package": "dart/language_model",
409+
"version": "KB68QHR1SKtopACaf3TFcu9MusRbwWqs0L1m_urGLL4C",
410+
}
411+
],
412+
"dep_type": "cipd",
413+
},
414+
402415
Var("dart_root") + "/buildtools": {
403416
"packages": [
404417
{
Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
// Copyright (c) 2019, the Dart project authors. Please see the AUTHORS file
2+
// for details. All rights reserved. Use of this source code is governed by a
3+
// BSD-style license that can be found in the LICENSE file.
4+
5+
import 'dart:io';
6+
import 'dart:convert';
7+
import 'dart:typed_data';
8+
9+
import 'package:path/path.dart' as path;
10+
import 'package:quiver/check.dart';
11+
import 'package:tflite_native/tflite.dart' as tfl;
12+
13+
/// Interface to TensorFlow-based Dart language model for next-token prediction.
14+
class LanguageModel {
15+
static const _defaultCompletions = 100;
16+
17+
final tfl.Interpreter _interpreter;
18+
final Map<String, int> _word2idx;
19+
final Map<int, String> _idx2word;
20+
final int _lookback;
21+
22+
LanguageModel._(
23+
this._interpreter, this._word2idx, this._idx2word, this._lookback);
24+
25+
/// Number of previous tokens to look at during predictions.
26+
int get lookback => _lookback;
27+
28+
/// Number of completion results to return during predictions.
29+
int get completions => _defaultCompletions;
30+
31+
/// Load model from directory.
32+
factory LanguageModel.load(String directory) {
33+
// Load model.
34+
final interpreter =
35+
tfl.Interpreter.fromFile(path.join(directory, 'model.tflite'));
36+
interpreter.allocateTensors();
37+
38+
// Load word2idx mapping for input.
39+
final word2idx = json
40+
.decode(File(path.join(directory, 'word2idx.json')).readAsStringSync())
41+
.cast<String, int>();
42+
43+
// Load idx2word mapping for output.
44+
final idx2word = json
45+
.decode(File(path.join(directory, 'idx2word.json')).readAsStringSync())
46+
.map<int, String>((k, v) => MapEntry<int, String>(int.parse(k), v));
47+
48+
// Get lookback size from model input tensor shape.
49+
final tensorShape = interpreter.getInputTensors().single.shape;
50+
checkArgument(tensorShape.length == 2 && tensorShape.first == 1,
51+
message:
52+
'tensor shape $tensorShape does not match the expected [1, X]');
53+
final lookback = tensorShape.last;
54+
55+
return LanguageModel._(interpreter, word2idx, idx2word, lookback);
56+
}
57+
58+
/// Tear down the interpreter.
59+
void close() {
60+
_interpreter.delete();
61+
}
62+
63+
/// Predicts the next token to follow a list of precedent tokens
64+
///
65+
/// Returns a list of tokens, sorted by most probable first.
66+
List<String> predict(Iterable<String> tokens) =>
67+
predictWithScores(tokens).keys.toList();
68+
69+
/// Predicts the next token with confidence scores.
70+
///
71+
/// Returns an ordered map of tokens to scores, sorted by most probable first.
72+
Map<String, double> predictWithScores(Iterable<String> tokens) {
73+
final tensorIn = _interpreter.getInputTensors().single;
74+
tensorIn.data = _transformInput(tokens);
75+
_interpreter.invoke();
76+
final tensorOut = _interpreter.getOutputTensors().single;
77+
return _transformOutput(tensorOut.data);
78+
}
79+
80+
/// Transforms tokens to data bytes that can be used as interpreter input.
81+
List<int> _transformInput(Iterable<String> tokens) {
82+
// Replace out of vocabulary tokens.
83+
final sanitizedTokens = tokens
84+
.map((token) => _word2idx.containsKey(token) ? token : '<unknown>');
85+
86+
// Get indexes (as floats).
87+
final indexes = Float32List(lookback)
88+
..setAll(0, sanitizedTokens.map((token) => _word2idx[token].toDouble()));
89+
90+
// Get bytes
91+
return Uint8List.view(indexes.buffer);
92+
}
93+
94+
/// Transforms interpreter output data to map of tokens to scores.
95+
Map<String, double> _transformOutput(List<int> databytes) {
96+
// Get bytes.
97+
final bytes = Uint8List.fromList(databytes);
98+
99+
// Get scores (as floats)
100+
final probabilities = Float32List.view(bytes.buffer);
101+
102+
// Get indexes with scores, sorted by scores (descending)
103+
final entries = probabilities.asMap().entries.toList()
104+
..sort((a, b) => b.value.compareTo(a.value));
105+
106+
// Get tokens with scores, limiting the length.
107+
return Map.fromEntries(entries.sublist(0, completions))
108+
.map((k, v) => MapEntry(_idx2word[k], v));
109+
}
110+
}

pkg/analysis_server/pubspec.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ dependencies:
1919
source_span: any
2020
package_config: any
2121
path: any
22+
tflite_native: any
2223
watcher: any
2324
yaml: any
2425

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
// Copyright (c) 2019, the Dart project authors. Please see the AUTHORS file
2+
// for details. All rights reserved. Use of this source code is governed by a
3+
// BSD-style license that can be found in the LICENSE file.
4+
5+
import 'dart:ffi';
6+
import 'dart:io';
7+
8+
import 'package:analysis_server/src/services/completion/dart/language_model.dart';
9+
import 'package:path/path.dart' as path;
10+
import 'package:test/test.dart';
11+
import 'package:test_reflective_loader/test_reflective_loader.dart';
12+
13+
final directory =
14+
Platform.script.resolve('../../../../language_model/lexeme').path;
15+
const expectedLookback = 100;
16+
17+
void main() {
18+
if (Platform.isWindows || sizeOf<IntPtr>() == 4) {
19+
// We don't yet support running tflite on Windows or 32-bit systems.
20+
return;
21+
}
22+
23+
LanguageModel model;
24+
25+
setUp(() {
26+
model = LanguageModel.load(directory);
27+
});
28+
29+
tearDown(() {
30+
model.close();
31+
});
32+
33+
test('calculates lookback', () {
34+
expect(model.lookback, expectedLookback);
35+
});
36+
37+
test('predict with defaults', () {
38+
final tokens =
39+
tokenize('if (list == null) { return; } for (final i = 0; i < list.');
40+
final suggestions = model.predict(tokens);
41+
expect(suggestions, hasLength(model.completions));
42+
expect(suggestions.first, 'length');
43+
});
44+
45+
test('predict with confidence scores', () {
46+
final tokens =
47+
tokenize('if (list == null) { return; } for (final i = 0; i < list.');
48+
final suggestions = model.predictWithScores(tokens);
49+
final best = suggestions.entries.first;
50+
expect(best.key, 'length');
51+
expect(best.value, greaterThan(0.8));
52+
});
53+
54+
test('predict when no previous tokens', () {
55+
final tokens = <String>[];
56+
final suggestions = model.predict(tokens);
57+
expect(suggestions, hasLength(model.completions));
58+
expect(suggestions.first, isNotEmpty);
59+
});
60+
61+
test('load fail', () {
62+
try {
63+
LanguageModel.load('doesnotexist');
64+
fail('Failure to load language model should throw an exception');
65+
} catch (e) {
66+
expect(
67+
e.toString(), equals('Invalid argument(s): Unable to create model.'));
68+
}
69+
});
70+
}
71+
72+
/// Tokenizes the input string.
73+
///
74+
/// The input is split by word boundaries and trimmed of whitespace.
75+
List<String> tokenize(String input) =>
76+
input.split(RegExp(r'\b|\s')).map((t) => t.trim()).toList()
77+
..removeWhere((t) => t.isEmpty);

pkg/analysis_server/test/services/completion/dart/test_all.dart

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import 'imported_reference_contributor_test.dart' as imported_ref_test;
1313
import 'inherited_reference_contributor_test.dart' as inherited_ref_test;
1414
import 'keyword_contributor_test.dart' as keyword_test;
1515
import 'label_contributor_test.dart' as label_contributor_test;
16+
import 'language_model_test.dart' as language_model_test;
1617
import 'library_member_contributor_test.dart' as library_member_test;
1718
import 'library_prefix_contributor_test.dart' as library_prefix_test;
1819
import 'local_constructor_contributor_test.dart' as local_constructor_test;
@@ -37,6 +38,7 @@ main() {
3738
inherited_ref_test.main();
3839
keyword_test.main();
3940
label_contributor_test.main();
41+
language_model_test.main();
4042
library_member_test.main();
4143
library_prefix_test.main();
4244
local_constructor_test.main();

0 commit comments

Comments
 (0)