@@ -449,8 +449,7 @@ def gptqmodel_post_init(model, use_act_order: bool, quantize_config: QuantizeCon
     """
     # post init for bitblas backend.
    device_to_buffers_size = {}
-    # exllama
-    model_uses_exllama = False
+
     model_uses_qbits = False
 
     # exllamav2
@@ -467,54 +466,6 @@ def gptqmodel_post_init(model, use_act_order: bool, quantize_config: QuantizeCon
             scratch_fixed = submodule.scratch_space_fixed()
             fixed_bytes[device] = max(scratch_fixed, fixed_bytes.get(device, 0))
 
-    if model_uses_exllama:
-        # To be honest this is quite ugly, not proud of this.
-        from gptqmodel_exllama_kernels import prepare_buffers, set_tuning_params
-
-        device_to_buffers = {}
-
-        if use_act_order:
-            if max_input_length is None:
-                max_input_len = EXLLAMA_DEFAULT_MAX_INPUT_LENGTH
-            else:
-                max_input_len = max_input_length
-        else:
-            if max_input_length is not None:
-                logger.info(
-                    "Using exllama backend without act-order, the parameter max_input_length was set although not needed, it will be ignored."
-                )
-            max_input_len = 1
-
-        for device, buffers_size in device_to_buffers_size.items():
-            # The temp_state buffer is required to reorder X in the act-order case.
-            # The temp_dq buffer is required to dequantize weights when using cuBLAS, typically for the prefill.
-            device_to_buffers[device] = {
-                "temp_state": torch.zeros(
-                    (max_input_len, buffers_size["max_inner_outer_dim"]),
-                    dtype=torch.float16,
-                    device=device,
-                ),
-                "temp_dq": torch.zeros(
-                    (1, buffers_size["max_dq_buffer_size"]),
-                    dtype=torch.float16,
-                    device=device,
-                ),
-                "max_dq_buffer_size": buffers_size["max_dq_buffer_size"],
-                "max_inner_outer_dim": buffers_size["max_inner_outer_dim"],
-            }
-
-        # Buffers need to be persistent to avoid any bug.
-        model.device_to_buffers = device_to_buffers
-
-        for device, buffers in model.device_to_buffers.items():
-            prepare_buffers(device, buffers["temp_state"], buffers["temp_dq"])
-
-        # Using the default from exllama repo here.
-        matmul_recons_thd = 8
-        matmul_fused_remap = False
-        matmul_no_half2 = False
-        set_tuning_params(matmul_recons_thd, matmul_fused_remap, matmul_no_half2)
-
     if model_uses_exllamav2:
         from ..nn_modules.qlinear.qlinear_exllamav2 import ExLlamaV2DeviceTensors
 
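For context on the surviving path: after this change only the exllamav2 branch prepares per-device state, and it does so by reserving a single fixed scratch allocation per device, sized by a max-reduction over all quantized submodules (the two retained `scratch_fixed` lines above). Below is a minimal sketch of that sizing pattern; the function name `collect_fixed_scratch` and the `hasattr` filter are illustrative stand-ins (the real code checks for the exllamav2 QuantLinear class), assuming submodules that expose `scratch_space_fixed()` and a `qweight` tensor.

    def collect_fixed_scratch(model) -> dict:
        # Per-device scratch requirement: each device keeps the largest
        # fixed scratch size that any of its quantized submodules requests.
        fixed_bytes = {}
        for _, submodule in model.named_modules():
            # Stand-in check; the real code tests isinstance against the
            # exllamav2 QuantLinear class.
            if hasattr(submodule, "scratch_space_fixed"):
                device = submodule.qweight.device
                scratch_fixed = submodule.scratch_space_fixed()
                fixed_bytes[device] = max(scratch_fixed, fixed_bytes.get(device, 0))
        return fixed_bytes

This mapping presumably feeds the ExLlamaV2DeviceTensors setup imported in the retained branch, replacing the removed per-device temp_state/temp_dq buffer pairs of the old exllama v1 path.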