44from __future__ import unicode_literals
55
66import numpy as np
7- import time
87import unittest
98
109# Must happen before importing caffe2.python.*
1110import caffe2 .python .fakelowp .init_shared_libs # noqa
1211
13- from hypothesis import given , settings
12+ from hypothesis import given
1413from hypothesis import strategies as st
1514from caffe2 .proto import caffe2_pb2
16- from caffe2 .python import core , workspace , dyndep
15+ from caffe2 .python import core , workspace
1716from caffe2 .python .onnx .onnxifi import onnxifi_caffe2_net
18- from caffe2 .python .onnx .tests .test_utils import TestCase
1917from caffe2 .python .fakelowp .test_utils import print_test_debug_info
2018
2119workspace .GlobalInit (["caffe2" , "--glow_global_fp16=1" ,
2220 "--glow_global_fused_scale_offset_fp16=1" ,
2321 "--glow_global_force_sls_fp16_accum=1" ])
2422
2523
26- class SparseLengthsSumTest (unittest .TestCase ):
24+ class SparseLengthsSum4BitFakeNNPIFp16Test (unittest .TestCase ):
2725 @given (seed = st .integers (0 , 65535 ))
2826 def test_slws_fused_4bit_rowwise_all_same (self , seed ):
2927 np .random .seed (seed )
3028 workspace .ResetWorkspace ()
3129 n = 1
3230 m = 2
3331 data = np .ones ((n , m )).astype (np .float32 ) * 0.2 - 0.1
34-
3532 max_segments = 5
3633 max_segment_length = 100
3734 num_lengths = np .random .randint (1 , max_segments + 1 )
@@ -43,7 +40,6 @@ def test_slws_fused_4bit_rowwise_all_same(self, seed):
4340 weights = np .random .uniform (low = - 0.5 , high = 0.5 ,
4441 size = [len (indices )]).astype (np .float32 )
4542 weights = np .ones (len (indices )).astype (np .float32 )
46-
4743 pred_net = caffe2_pb2 .NetDef ()
4844 pred_net .name = "pred"
4945 pred_net .external_input .extend (
@@ -56,7 +52,6 @@ def test_slws_fused_4bit_rowwise_all_same(self, seed):
5652 ["Y" ],
5753 )
5854 )
59-
6055 ref_net = caffe2_pb2 .NetDef ()
6156 ref_net .name = "ref"
6257 ref_net .external_input .extend (
@@ -69,7 +64,6 @@ def test_slws_fused_4bit_rowwise_all_same(self, seed):
6964 ["Y" ],
7065 )
7166 )
72-
7367 workspace .FeedBlob ("data" , data )
7468 workspace .RunOperatorOnce (
7569 core .CreateOperator (
@@ -78,7 +72,6 @@ def test_slws_fused_4bit_rowwise_all_same(self, seed):
7872 ['quantized_data' ]
7973 )
8074 )
81-
8275 print ("quantized" , workspace .FetchBlob ("quantized_data" ))
8376 pred_net_onnxified = onnxifi_caffe2_net (
8477 pred_net ,
@@ -89,24 +82,18 @@ def test_slws_fused_4bit_rowwise_all_same(self, seed):
8982 adjust_batch = True ,
9083 use_onnx = False
9184 )
92-
9385 num_onnxified_ops = sum (
9486 1 if o .type == "Onnxifi" else 0 for o in pred_net_onnxified .op )
9587 np .testing .assert_equal (num_onnxified_ops , 1 )
96-
9788 workspace .FeedBlob ("indices" , indices )
9889 workspace .FeedBlob ("lengths" , lengths )
9990 workspace .FeedBlob ("weights" , weights )
100-
10191 workspace .CreateNet (pred_net_onnxified )
10292 workspace .CreateNet (ref_net )
103-
10493 workspace .RunNet (pred_net_onnxified .name )
10594 Y_glow = workspace .FetchBlob ('Y' )
106-
10795 workspace .RunNet (ref_net .name )
10896 Y_c2 = workspace .FetchBlob ('Y' )
109-
11097 if not np .allclose (Y_c2 , Y_glow ):
11198 print_test_debug_info (
11299 "slws_fused_4bit_rowwise" ,
@@ -121,33 +108,35 @@ def test_slws_fused_4bit_rowwise_all_same(self, seed):
121108 "rowwise_diff" : (Y_glow - Y_c2 )[:, 0 ]})
122109 assert (0 )
123110
124- @given (seed = st .integers (0 , 65535 ))
125- def test_slws_fused_4bit_rowwise (self , seed ):
126- np .random .seed (seed )
111+
112+ @given (
113+ seed = st .integers (0 , 65535 ),
114+ num_rows = st .integers (2 , 20 ),
115+ embedding_dim = st .sampled_from ([8 , 12 , 16 , 24 , 32 , 54 , 64 , 128 ]),
116+ batch_size = st .integers (1 , 5 ),
117+ max_weight = st .integers (0 , 100 ),
118+ )
119+ def test_slws_fused_4bit_rowwise (self , seed , num_rows , embedding_dim , batch_size , max_weight ):
127120 workspace .ResetWorkspace ()
121+ np .random .seed (seed )
122+ data = np .random .rand (num_rows , embedding_dim ).astype (np .float32 )
123+ lengths = np .random .choice (np .arange (1 , num_rows ), batch_size ).astype (np .int32 )
128124
129- n = 20000
130- DIM = 6
131- data = (4 * np .random .random_sample ((n , DIM )) + 1 ).astype (np .float32 )
125+ indices = []
126+ for length in lengths :
127+ indices .extend (np .random .choice (np .arange (1 , num_rows ), length ))
128+ indices = np .asarray (indices ).astype (np .int64 )
132129
133- max_segments = 200
134- max_segment_length = 200
135- num_lengths = np .random .randint (0 , max_segments + 1 )
136- # number of segments to run
137- lengths = np .random .randint (2 , max_segment_length + 1 , size = num_lengths ).astype (
138- np .int32
139- )
140- num_indices = np .sum (lengths )
141- indices = np .random .randint (low = 0 , high = n , size = num_indices , dtype = np .int64 )
142- weights = np .random .uniform (low = 0.01 , high = 0.5 , size = [len (indices )]).astype (
143- np .float32
144- )
130+ weights = np .random .uniform (
131+ low = 0 ,
132+ high = max_weight ,
133+ size = [len (indices )]
134+ ).astype (np .float32 )
145135
146136 pred_net = caffe2_pb2 .NetDef ()
147137 pred_net .name = "pred"
148138 pred_net .external_input .extend (
149- ["quantized_data" , "weights" , "indices" , "lengths" ]
150- )
139+ ["quantized_data" , "weights" , "indices" , "lengths" ])
151140 pred_net .external_output .append ("Y" )
152141 pred_net .op .add ().CopyFrom (
153142 core .CreateOperator (
@@ -160,8 +149,7 @@ def test_slws_fused_4bit_rowwise(self, seed):
160149 ref_net = caffe2_pb2 .NetDef ()
161150 ref_net .name = "ref"
162151 ref_net .external_input .extend (
163- ["quantized_data" , "weights" , "indices" , "lengths" ]
164- )
152+ ["quantized_data" , "weights" , "indices" , "lengths" ])
165153 ref_net .external_output .append ("Y" )
166154 ref_net .op .add ().CopyFrom (
167155 core .CreateOperator (
@@ -174,49 +162,52 @@ def test_slws_fused_4bit_rowwise(self, seed):
174162 workspace .FeedBlob ("data" , data )
175163 workspace .RunOperatorOnce (
176164 core .CreateOperator (
177- "FloatToFused4BitRowwiseQuantized" , ["data" ], ["quantized_data" ]
165+ "FloatToFused4BitRowwiseQuantized" ,
166+ ["data" ],
167+ ["quantized_data" ]
178168 )
179169 )
180- onnxified_net = onnxifi_caffe2_net (
170+
171+ pred_net_onnxified = onnxifi_caffe2_net (
181172 pred_net ,
182173 {},
183- max_batch_size = max_segments ,
184- max_seq_size = max_segments * max_segment_length ,
174+ max_batch_size = batch_size ,
175+ max_seq_size = batch_size * np . max ( lengths ) ,
185176 debug = True ,
186177 adjust_batch = True ,
187- use_onnx = False ,
178+ use_onnx = False
188179 )
180+
181+ num_onnxified_ops = sum (
182+ 1 if o .type == "Onnxifi" else 0 for o in pred_net_onnxified .op )
183+ np .testing .assert_equal (num_onnxified_ops , 1 )
184+
189185 workspace .FeedBlob ("indices" , indices )
190186 workspace .FeedBlob ("lengths" , lengths )
191187 workspace .FeedBlob ("weights" , weights )
192188
193- workspace .CreateNet (onnxified_net )
189+ workspace .CreateNet (pred_net_onnxified )
194190 workspace .CreateNet (ref_net )
195191
196- workspace .RunNet (onnxified_net .name )
197- Y_glow = workspace .FetchBlob ("Y" )
192+ workspace .RunNet (pred_net_onnxified .name )
193+ Y_glow = workspace .FetchBlob ('Y' )
198194
199195 workspace .RunNet (ref_net .name )
200- Y_ref = workspace .FetchBlob ("Y" )
196+ Y_c2 = workspace .FetchBlob ('Y' )
201197
202- diff = np .abs ((Y_ref - Y_glow ) / (Y_ref + 1e-8 ))
203- max_err = np .max (diff , axis = 1 )
204- num_offenders = (max_err > 0 ).sum ()
205- if num_offenders > 0 :
198+ if not np .allclose (Y_c2 , Y_glow ):
206199 print_test_debug_info (
207- "slws_fused_4bit" ,
208- {
209- "indices" : indices ,
210- "data" : data .shape ,
211- "lengths" : lengths ,
212- "weights" : weights ,
213- "Y_glow" : Y_glow ,
214- "Y_ref" : Y_ref ,
215- "diff" : diff ,
216- "rowwise_diff" : np .max (diff , axis = 1 ),
217- },
218- )
219- assert 0
200+ "slws_fused_4bit_rowwise" ,
201+ {"seed" : seed ,
202+ "indices" : indices ,
203+ "data" : data ,
204+ "lengths" : lengths ,
205+ "weights" : weights ,
206+ "Y_c2" : Y_c2 ,
207+ "Y_glow" : Y_glow ,
208+ "diff" : Y_glow - Y_c2 ,
209+ "rowwise_diff" : (Y_glow - Y_c2 )[:, 0 ]})
210+ assert (0 )
220211
221212if __name__ == '__main__' :
222213 unittest .main ()
0 commit comments