From 646d0cdb8695f6013b8319c78e263b1c8f49cbec Mon Sep 17 00:00:00 2001 From: tpoisonooo Date: Wed, 1 Jun 2022 11:07:29 +0800 Subject: [PATCH 01/22] feat(default.py): add support ViT opr --- ppq/executor/op/torch/default.py | 101 +++++++++++++++++++++-------- ppq/executor/torch.py | 1 + ppq/quantization/quantizer/base.py | 4 +- 3 files changed, 77 insertions(+), 29 deletions(-) diff --git a/ppq/executor/op/torch/default.py b/ppq/executor/op/torch/default.py index e97099a0..d1e9da7b 100644 --- a/ppq/executor/op/torch/default.py +++ b/ppq/executor/op/torch/default.py @@ -333,6 +333,35 @@ def Mul_forward(op: Operation, values: List[torch.Tensor], ctx: TorchBackendCont return multiplicand * multiplier +def MultiHeadAttention_forward(op: Operation, values: List[torch.Tensor], ctx: TorchBackendContext = None, **kwargs) -> torch.Tensor: + if len(values) != 11: + raise NotImplementedError('Not implement simplified MultiHeadAttention') + + q,k,v,q_w,q_b,k_w,k_b,v_w,v_b,o_w,o_b = values + embed_dim = op.attributes.get('embed_dim') + num_heads = op.attributes.get('num_heads') + + if embed_dim is None or num_heads is None: + raise ValueError('Cannot fetch embed_dim or num_heads') + + # setup parameters + batch_size = q.shape[0] + head_dim = embed_dim // num_heads + scale = head_dim ** -0.5 + + q = F.linear(q, q_w, q_b) + k = F.linear(k, k_w, k_b) + v = F.linear(v, v_w, v_b) + + energy = (q @ k.transpose(-2, -1)) * scale + attn = energy.softmax(dim=-1) + + x = (attn @ v).transpose(1, 2).reshape(batch_size, -1, embed_dim) + x = F.linear(x, o_w, o_b) + + return x + + def Add_forward(op: Operation, values: List[torch.Tensor], ctx: TorchBackendContext = None, **kwargs) -> torch.Tensor: """Performs element-wise binary addition (with Numpy-style broadcasting support). @@ -786,6 +815,9 @@ def GatherND_forward(op: Operation, values: List[torch.Tensor], ctx: TorchBacken reshaped_output = output.reshape(*shape_i, *shape_j, *shape_k) return output +def Gelu_forward(op: Operation, values: List[torch.Tensor], ctx: TorchBackendContext = None, **kwargs) -> torch.Tensor: + [input_value] = values + return F.gelu(input_value) def Greater_forward(op: Operation, values: List[torch.Tensor], ctx: TorchBackendContext = None, **kwargs) -> torch.Tensor: input_a, input_b = values @@ -1436,7 +1468,7 @@ def Split_forward(op: Operation, values: List[torch.Tensor], ctx: TorchBackendCo split = op.attributes.get('split', 0) [input_value] = values if 'split' not in op.attributes: - split = input_value.shape[axis] // len(op.outputs) + split = input_value.shape[axis] // len(op.outputs) outputs = torch.split(input_value, split, axis) return outputs @@ -1525,6 +1557,18 @@ def LeakyRelu_forward(op: Operation, values: List[torch.Tensor], ctx: TorchBacke return output +def LayerNorm_forward(op: Operation, values: List[torch.Tensor], ctx: TorchBackendContext = None, **kwargs): + if len(values) != 3: + raise ValueError('Unsupported LayerNorm without affine') + + input_data, weight, bias = values + eps = op.attributes.get('epsilon', 1e-5) + normalized_shape = weight.shape + + output = F.layer_norm(input_data, normalized_shape, weight, bias, eps) + return output + + def Pad_forward(op: Operation, values: List[torch.Tensor], ctx: TorchBackendContext = None, **kwargs): mode = op.attributes.get('mode', 'constant') input_data = values[0] @@ -2118,20 +2162,20 @@ def Identity_forward(op: Operation, values: List[torch.Tensor], ctx: TorchBacken return values[0] def Onehot_forward(op: Operation, values: List[torch.Tensor], ctx: TorchBackendContext = None, **kwargs) -> torch.Tensor: - """ - Produces a one-hot tensor based on inputs. The locations represented by the index values in the 'indices' - input tensor will have 'on_value' and the other locations will have 'off_value' in the output tensor, - - where 'on_value' and 'off_value' are specified as part of required input argument 'values', - which is a two-element tensor of format [off_value, on_value]. - - The rank of the output tensor will be one greater than the rank of the input tensor. - The additional dimension is for one-hot representation. The additional dimension will be inserted at the position specified by 'axis'. - If 'axis' is not specified then then additional dimension will be inserted as the innermost dimension, - i.e. axis=-1. The size of the additional dimension is specified by required scalar input 'depth'. - - The type of the output tensor is the same as the type of the 'values' input. Any entries in the 'indices' - input tensor with values outside the range [-depth, depth-1] will result in one-hot representation + """Produces a one-hot tensor based on inputs. The locations represented by + the index values in the 'indices' input tensor will have 'on_value' and the + other locations will have 'off_value' in the output tensor, + + where 'on_value' and 'off_value' are specified as part of required input argument 'values', + which is a two-element tensor of format [off_value, on_value]. + + The rank of the output tensor will be one greater than the rank of the input tensor. + The additional dimension is for one-hot representation. The additional dimension will be inserted at the position specified by 'axis'. + If 'axis' is not specified then then additional dimension will be inserted as the innermost dimension, + i.e. axis=-1. The size of the additional dimension is specified by required scalar input 'depth'. + + The type of the output tensor is the same as the type of the 'values' input. Any entries in the 'indices' + input tensor with values outside the range [-depth, depth-1] will result in one-hot representation with all 'off_value' values in the output tensor. when axis = 0: @@ -2144,30 +2188,30 @@ def Onehot_forward(op: Operation, values: List[torch.Tensor], ctx: TorchBackendC Attributes axis : int (default is -1) - (Optional) Axis along which one-hot representation in added. Default: axis=-1. axis=-1 means that - the additional dimension will be inserted as the innermost/last dimension in the output tensor. + (Optional) Axis along which one-hot representation in added. Default: axis=-1. axis=-1 means that + the additional dimension will be inserted as the innermost/last dimension in the output tensor. Negative value means counting dimensions from the back. Accepted range is [-r-1, r] where r = rank(indices). - + Inputs indices (non-differentiable) : T1 Input tensor containing indices. Any entries in the 'indices' input tensor with values outside the range [-depth, depth-1] - will result in one-hot representation with all 'off_value' values in the output tensor.In case 'indices' is of non-integer type, + will result in one-hot representation with all 'off_value' values in the output tensor.In case 'indices' is of non-integer type, the values will be casted to int64 before use. - + depth (non-differentiable) : T2 - Scalar specifying the number of classes in one-hot tensor. + Scalar specifying the number of classes in one-hot tensor. This is also the size of the one-hot dimension (specified by 'axis' attribute) added on in the output tensor. - The values in the 'indices' input tensor are expected to be in the range [-depth, depth-1]. + The values in the 'indices' input tensor are expected to be in the range [-depth, depth-1]. In case 'depth' is of non-integer type, it will be casted to int64 before use. values (non-differentiable) : T3 - Rank 1 tensor containing exactly two elements, - in the format [off_value, on_value], where 'on_value' is the value used for filling locations specified in 'indices' input tensor, + Rank 1 tensor containing exactly two elements, + in the format [off_value, on_value], where 'on_value' is the value used for filling locations specified in 'indices' input tensor, and 'off_value' is the value used for filling locations other than those specified in 'indices' input tensor. Outputs output (non-differentiable) : T3 - Tensor of rank one greater than input tensor 'indices', i.e. rank(output) = rank(indices) + 1. + Tensor of rank one greater than input tensor 'indices', i.e. rank(output) = rank(indices) + 1. The data type for the elements of the output tensor is the same as the type of input 'values' is used. """ # implementation from https://github.com/ToriML/onnx2pytorch/blob/master/onnx2pytorch/operations/onehot.py @@ -2187,10 +2231,10 @@ def Onehot_forward(op: Operation, values: List[torch.Tensor], ctx: TorchBackendC order.insert(axis, -1) out = out.permute(order) return out - + def Reciprocal_forward(op: Operation, values: List[torch.Tensor], ctx: TorchBackendContext = None, **kwargs) -> torch.Tensor: """ - Reciprocal takes one input data (Tensor) and produces one output data (Tensor) where the reciprocal is, + Reciprocal takes one input data (Tensor) and produces one output data (Tensor) where the reciprocal is, y = 1/x, is applied to the tensor elementwise. Version @@ -2231,11 +2275,13 @@ def Reciprocal_forward(op: Operation, values: List[torch.Tensor], ctx: TorchBack 'Gather': Gather_forward, 'GatherElements': Gather_forward, 'GatherND': GatherND_forward, + 'Gelu': Gelu_forward, 'Gemm': Gemm_forward, 'grid_sampler': Grid_sampler_forward, 'GlobalAveragePool': AveragePool_forward, 'GlobalMaxPool': MaxPool2d_forward, 'Greater': Greater_forward, + 'LayerNorm': LayerNorm_forward, 'LeakyRelu': LeakyRelu_forward, 'Less': Less_forward, 'MatMul': MatMul_forward, @@ -2243,6 +2289,7 @@ def Reciprocal_forward(op: Operation, values: List[torch.Tensor], ctx: TorchBack 'MaxPool': MaxPool2d_forward, 'Min': Eltwise_forward, 'Mul': Mul_forward, + 'MultiHeadAttention': MultiHeadAttention_forward, 'NonMaxSuppression': _NMS_forward, 'NonZero': NonZero_forward, 'Not': Not_forward, diff --git a/ppq/executor/torch.py b/ppq/executor/torch.py index 16cd31f3..7687c14c 100644 --- a/ppq/executor/torch.py +++ b/ppq/executor/torch.py @@ -327,6 +327,7 @@ def __forward( result_collector[output_names.index(name)] = inputs[name] for operation in executing_order[: last_idx]: + print(str(operation)) try: assert isinstance(operation, Operation), 'Oops, seems you got something weird in your graph' assert isinstance(operation.platform, TargetPlatform), ( diff --git a/ppq/quantization/quantizer/base.py b/ppq/quantization/quantizer/base.py index 0556f49d..491f6b45 100644 --- a/ppq/quantization/quantizer/base.py +++ b/ppq/quantization/quantizer/base.py @@ -110,12 +110,12 @@ def quantize_operations( operation_platforms[op_name] = self.target_platform else: operation_platforms[op_name] = self.default_platform - # maunnl override. + # manual override. if op_name in operation_platforms: operation.platform = operation_platforms[op_name] # build operation_quantization_configs - # every quantable op MUST have a quantization config + # every quantizable op MUST have a quantization config # if operation.type is listed in quantable_operation_types while a operation_quantization_configs is given # it will override the setting of quantable_operation_types for op_name, operation in self._graph.operations.items(): From c894bf704343b068bce87d3552ac27627ce84f6f Mon Sep 17 00:00:00 2001 From: tpoisonooo Date: Wed, 1 Jun 2022 11:09:20 +0800 Subject: [PATCH 02/22] docs(torch.py): remove useless --- ppq/executor/torch.py | 1 - 1 file changed, 1 deletion(-) diff --git a/ppq/executor/torch.py b/ppq/executor/torch.py index 7687c14c..16cd31f3 100644 --- a/ppq/executor/torch.py +++ b/ppq/executor/torch.py @@ -327,7 +327,6 @@ def __forward( result_collector[output_names.index(name)] = inputs[name] for operation in executing_order[: last_idx]: - print(str(operation)) try: assert isinstance(operation, Operation), 'Oops, seems you got something weird in your graph' assert isinstance(operation.platform, TargetPlatform), ( From 5373008f22ddc6936251320b007545db5ec84525 Mon Sep 17 00:00:00 2001 From: tpoisonooo Date: Sun, 5 Jun 2022 21:48:04 +0800 Subject: [PATCH 03/22] feat(VIT): support new opr fix layernorm user percentile update update remove useless concat config --- ppq/api/setting.py | 2 +- ppq/core/quant.py | 2 +- ppq/executor/torch.py | 9 +++- ppq/parser/ncnn_exporter.py | 2 + ppq/quantization/observer/range.py | 5 ++ ppq/quantization/optim/parameters.py | 4 +- ppq/quantization/quantizer/NCNNQuantizer.py | 58 ++++++++++++++++++--- ppq/quantization/quantizer/base.py | 5 ++ 8 files changed, 75 insertions(+), 12 deletions(-) diff --git a/ppq/api/setting.py b/ppq/api/setting.py index dbb393a4..f7cf0724 100644 --- a/ppq/api/setting.py +++ b/ppq/api/setting.py @@ -164,7 +164,7 @@ class ActivationQuantizationSetting(): def __init__(self) -> None: # 激活值校准算法,不区分大小写,可以选择 minmax, kl, percentile, MSE # activation calibration method - self.calib_algorithm = 'percentile' + self.calib_algorithm = 'minmax' # 执行逐层激活值校准,延长执行时间,提升精度 # whether to calibrate activation per - layer. diff --git a/ppq/core/quant.py b/ppq/core/quant.py index a9681aed..d604c140 100644 --- a/ppq/core/quant.py +++ b/ppq/core/quant.py @@ -658,7 +658,7 @@ class ChannelwiseTensorQuantizationConfig(TensorQuantizationConfig): """ChannelwiseTensorQuantizationConfig is a special case for tensor quantization configuration. - Comparing with per-tensor quantization configuration, pre-channel quantization has a property + Comparing with per-tensor quantization configuration, per-channel quantization has a property "channel_axis" to indicate a channel axis where quantization takes effects. Along this axis, all tensor values will be quantized with a sharing scale and offset, diff --git a/ppq/executor/torch.py b/ppq/executor/torch.py index 16cd31f3..7d145075 100644 --- a/ppq/executor/torch.py +++ b/ppq/executor/torch.py @@ -367,6 +367,9 @@ def __forward( outputs = outputs if isinstance(outputs, (list, tuple)) else [outputs] fp_outputs = outputs + for out in fp_outputs: + print("{} output shape {}".format(operation.name, out.shape)) + # quantize all result if is necessary if isinstance(operation, QuantableOperation): output_configs = [_ for _ in operation.config.output_quantization_config] @@ -390,8 +393,10 @@ def __forward( if output_var.name in output_names: result_collector[output_names.index(output_var.name)] = outputs[output_idx] - except Exception as _: - raise RuntimeError(f'Error happens when dealing with operation {str(operation)}') + except Exception as e: + import pdb + pdb.set_trace() + raise RuntimeError(f'Error happens when dealing with operation {str(operation)}, {str(e)}') # remove useless value(runtime clear). visited_op.append(operation) diff --git a/ppq/parser/ncnn_exporter.py b/ppq/parser/ncnn_exporter.py index 7285759a..5bf5210c 100644 --- a/ppq/parser/ncnn_exporter.py +++ b/ppq/parser/ncnn_exporter.py @@ -13,6 +13,8 @@ class NCNNExporter(GraphExporter): def export_quantization_config(self, config_path: str, graph: BaseGraph): fd = open(config_path, 'w+') topo_order = graph.topological_sort() + import pdb + pdb.set_trace() for op in topo_order: if op.is_computing_op and isinstance(op, QuantableOperation): fd.write(f'{op.name}_param_0 ') diff --git a/ppq/quantization/observer/range.py b/ppq/quantization/observer/range.py index 53d085c6..42f00a43 100644 --- a/ppq/quantization/observer/range.py +++ b/ppq/quantization/observer/range.py @@ -261,7 +261,10 @@ def observe(self, value: torch.Tensor): else: self._percentile_collector.append(CUDA.Quantile(value, self._percentile).view(1, -1)) elif self._quant_cfg.policy.has_property(QuantizationProperty.PER_CHANNEL): + import pdb + pdb.set_trace() raise PermissionError('Percentile observer can not deal with per channel quantization.') + assert isinstance(self._quant_cfg, ChannelwiseTensorQuantizationConfig), ( 'Your quantization config has PER_CHANNEL while it is not a ' 'ChannelwiseTensorQuantizationConfig instance.') @@ -289,6 +292,8 @@ def render_quantization_config(self): self._quant_cfg.offset = torch.tensor([offset], dtype=torch.float32, device=device).squeeze(0) self._quant_cfg.state = QuantizationStates.ACTIVATED elif self._quant_cfg.policy.has_property(QuantizationProperty.PER_CHANNEL): + import pdb + pdb.set_trace() raise PermissionError('Percentile observer can not deal with per channel quantization.') if len(self._percentile_maxs) == 0: raise ValueError('Can not render quantization config yet, Observer data collator is empty. ' \ diff --git a/ppq/quantization/optim/parameters.py b/ppq/quantization/optim/parameters.py index 4b45b929..1888124d 100644 --- a/ppq/quantization/optim/parameters.py +++ b/ppq/quantization/optim/parameters.py @@ -125,8 +125,10 @@ def optimize( state_records[config] = config.state config.state = QuantizationStates.DEQUANTIZED elif self._method is not None: + config.observer_algorithm = 'Minmax' + # pass # override quantizer's setting if necessary - config.observer_algorithm = self._method + # config.observer_algorithm = self._method observer = OperationObserver( operation=executor._graph.operations[op_name], diff --git a/ppq/quantization/quantizer/NCNNQuantizer.py b/ppq/quantization/quantizer/NCNNQuantizer.py index 91c5cf0a..95eb3e53 100644 --- a/ppq/quantization/quantizer/NCNNQuantizer.py +++ b/ppq/quantization/quantizer/NCNNQuantizer.py @@ -36,10 +36,10 @@ def init_quantize_config( policy=self.quantize_policy, rounding=self.rounding_policy, operation_meta=operation.meta_data, num_of_bits=self._num_of_bits, quant_max=self._quant_max, quant_min=self._quant_min, - observer_algorithm='percentile' + observer_algorithm='Minmax' ) - if operation.type in {'Conv', 'Gemm'}: + if operation.type in {'Conv', 'Gemm', 'Add', 'LayerNorm', 'MultiHeadAttention', 'Gelu', 'Concat'}: assert operation.num_of_input > 0, 'Seems you got a Computing layer with no parameters.' if operation.type == 'Conv': @@ -81,12 +81,55 @@ def init_quantize_config( offsets = None, scales = None, channel_axis = 0 ) base_quant_config.input_quantization_config[1].observer_algorithm = 'Minmax' + + elif operation.type == 'Concat': + base_quant_config.input_quantization_config[1].policy = QuantizationPolicy( + QuantizationProperty.SYMMETRICAL + + QuantizationProperty.LINEAR + + QuantizationProperty.PER_TENSOR + ) + base_quant_config.input_quantization_config[1].observer_algorithm = 'Minmax' + + elif operation.type == 'Add': + # Add 量化输入 + base_quant_config.input_quantization_config[1].policy = QuantizationPolicy( + QuantizationProperty.SYMMETRICAL + + QuantizationProperty.LINEAR + + QuantizationProperty.PER_TENSOR + ) + base_quant_config.input_quantization_config[1].observer_algorithm = 'Minmax' - # if operation has bias - if operation.num_of_input > 2: - bias_config = base_quant_config.input_quantization_config[-1] - bias_config.state = QuantizationStates.FP32 + elif operation.type == 'LayerNorm': + # LayerNorm 输入按 power of 2 量化 + # base_quant_config.input_quantization_config[0].policy = QuantizationPolicy( + # QuantizationProperty.SYMMETRICAL + + # QuantizationProperty.LINEAR + + # QuantizationProperty.PER_TENSOR + # ) + # base_quant_config.input_quantization_config[0].observer_algorithm = 'Minmax' + + inp_config = base_quant_config.input_quantization_config[0] + inp_config.policy = QuantizationPolicy( + QuantizationProperty.SYMMETRICAL + + QuantizationProperty.LINEAR + + QuantizationProperty.PER_CHANNEL + + QuantizationProperty.POWER_OF_2 + ) + base_quant_config.input_quantization_config[0] = \ + ChannelwiseTensorQuantizationConfig.convert_from_tensor_config( + convert_from = inp_config, + offsets = None, scales = None, channel_axis = 0 + ) + base_quant_config.input_quantization_config[0].observer_algorithm = 'Minmax' + + # layerNorm weight 和 bias 都不量化 + wconfig = base_quant_config.input_quantization_config[1] + bconfig = base_quant_config.input_quantization_config[2] + wconfig.state = QuantizationStates.FP32 + bconfig.state = QuantizationStates.FP32 + + # 显式说明输出不量化 base_quant_config.output_quantization_config[0].state = QuantizationStates.FP32 return base_quant_config @@ -100,8 +143,9 @@ def default_platform(self) -> TargetPlatform: @ property def quant_operation_types(self) -> set: + # 'Conv', 'Gemm', 'Concat', 'Add', 'LayerNorm', 'MultiHeadAttention', 'Gelu' return { - 'Conv', 'Gemm' + 'Conv', 'Gemm', 'Concat', 'Add', 'LayerNorm', 'Gelu', } @ property diff --git a/ppq/quantization/quantizer/base.py b/ppq/quantization/quantizer/base.py index 491f6b45..b8a547b8 100644 --- a/ppq/quantization/quantizer/base.py +++ b/ppq/quantization/quantizer/base.py @@ -232,8 +232,13 @@ def report(self) -> str: debug_str += '--------- Network Snapshot ---------\n' debug_str += f'Num of Op: [{len(self._graph.operations)}]\n' debug_str += f'Num of Quantized Op: [{len(quant_ops)}]\n' + # import pdb + # pdb.set_trace() + debug_str += " ".join([op.name for op in quant_ops]) + debug_str += '\n' debug_str += f'Num of Variable: [{len(self._graph.variables)}]\n' debug_str += f'Num of Quantized Var: [{len(quant_vars)}]\n' + debug_str += " ".join([var.name for var in quant_vars]) debug_str += '------- Quantization Snapshot ------\n' debug_str += f'Num of Quant Config: [{len(quant_cfgs)}]\n' for state, cnt in config_states_cnt.items(): From 00b60e6ec2d8208a7c9b0978b1267da79601e456 Mon Sep 17 00:00:00 2001 From: tpoisonooo Date: Mon, 6 Jun 2022 22:22:31 +0800 Subject: [PATCH 04/22] feat(NCNNExporter): add new opr --- ppq/core/common.py | 2 +- ppq/parser/ncnn_exporter.py | 80 ++++++++++++++------- ppq/quantization/optim/parameters.py | 4 +- ppq/quantization/quantizer/NCNNQuantizer.py | 7 +- 4 files changed, 62 insertions(+), 31 deletions(-) diff --git a/ppq/core/common.py b/ppq/core/common.py index 5444a0bb..5ee73d96 100644 --- a/ppq/core/common.py +++ b/ppq/core/common.py @@ -29,7 +29,7 @@ LINEAR_ACTIVATIONS = {'Relu', 'Clip'} # COPUTING OP 是所有计算层,该属性被用于联合定点和子图切分 -COMPUTING_OP = {'Conv', 'Gemm', 'ConvTranspose', 'MatMul'} +COMPUTING_OP = {'Conv', 'Gemm', 'ConvTranspose', 'MatMul', 'LayerNorm', 'MultiHeadAttention'} # SOI OP 是所有产生 SOI 输出的节点类型,该属性被用于子图切分 SOI_OP = {'TopK', 'Shape', 'NonMaxSuppression'} # 强制联合定点的算子种类 diff --git a/ppq/parser/ncnn_exporter.py b/ppq/parser/ncnn_exporter.py index 5bf5210c..dfcbf073 100644 --- a/ppq/parser/ncnn_exporter.py +++ b/ppq/parser/ncnn_exporter.py @@ -15,33 +15,65 @@ def export_quantization_config(self, config_path: str, graph: BaseGraph): topo_order = graph.topological_sort() import pdb pdb.set_trace() + # write weight scale for op in topo_order: - if op.is_computing_op and isinstance(op, QuantableOperation): - fd.write(f'{op.name}_param_0 ') - param_cfg = op.config.input_quantization_config[1] - assert param_cfg.state in {QuantizationStates.BAKED, QuantizationStates.ACTIVATED}\ - and param_cfg.observer_algorithm in {'minmax', 'Minmax'} and \ - param_cfg.policy.has_property(QuantizationProperty.PER_CHANNEL) - # a workaround for depthwise conv in ncnn - # will cause mis-alignment between ppq and ncnn - if op.type == 'Conv' and op.attributes.get('group', 1) > 1: - group = op.attributes.get('group', 1) - scale = param_cfg.scale.reshape(group, -1).max(dim=1)[0] + if hasattr(op, 'config'): + if op.type in {'Conv', 'Gemm'}: + fd.write(f'{op.name}_param_0 ') + param_cfg = op.config.input_quantization_config[1] + assert param_cfg.state in {QuantizationStates.BAKED, QuantizationStates.ACTIVATED}\ + and param_cfg.observer_algorithm in {'minmax', 'Minmax'} and \ + param_cfg.policy.has_property(QuantizationProperty.PER_CHANNEL) + # a workaround for depthwise conv in ncnn + # will cause mis-alignment between ppq and ncnn + if op.type == 'Conv' and op.attributes.get('group', 1) > 1: + group = op.attributes.get('group', 1) + scale = param_cfg.scale.reshape(group, -1).max(dim=1)[0] + else: + scale = param_cfg.scale + scale = convert_value(1 / scale, False, DataType.FP32) + for s in scale: + fd.write('%f '% s) + fd.write('\n') + elif op.type in {'Concat', 'Add', 'Gelu', 'LayerNorm'}: + # Concat and Add has no weight, skip + pass + elif op.type == 'MultiHeadAttention': + # TODO + pass else: - scale = param_cfg.scale - scale = convert_value(1 / scale, False, DataType.FP32) - for s in scale: - fd.write('%f '% s) - fd.write('\n') + print('unknown quant type {} name {} during write weight scale'.format(op.type, op.name)) for op in topo_order: - if op.is_computing_op and isinstance(op, QuantableOperation): - fd.write(f'{op.name} ') - input_cfg = op.config.input_quantization_config[0] - assert input_cfg.state == QuantizationStates.ACTIVATED and\ - input_cfg.policy.has_property(QuantizationProperty.PER_TENSOR) - scale = convert_value(1 / input_cfg.scale, True, DataType.FP32) - fd.write('%f '% scale) - fd.write('\n') + if hasattr(op, 'config'): + # write input scale + if op.type in {'Conv', 'Gemm'}: + fd.write(f'{op.name} ') + input_cfg = op.config.input_quantization_config[0] + assert input_cfg.state == QuantizationStates.ACTIVATED and\ + input_cfg.policy.has_property(QuantizationProperty.PER_TENSOR) + scale = convert_value(1 / input_cfg.scale, True, DataType.FP32) + fd.write('%f '% scale) + fd.write('\n') + elif op.type in {'Concat', 'Add', 'Gelu', 'LayerNorm'}: + fd.write(f'{op.name} ') + for cfg in op.config.input_quantization_config: + assert cfg.state in {QuantizationStates.BAKED, QuantizationStates.ACTIVATED} \ + and cfg.observer_algorithm in {'minmax', 'Minmax'} + scale = convert_value(1.0 / cfg.scale, True, DataType.FP32) + + if type(scale) == list: + for s in scale: + fd.write('%f '% s) + else: + fd.write('%f '% scale) + + fd.write('\n') + elif op.type == 'MultiHeadAttention': + # TODO + pass + else: + print('unknown quant type {} name {} during write input scale'.format(op.type, op.name)) + fd.close() def export(self, file_path: str, graph: BaseGraph, config_path: str = None, input_shapes: List[List[int]] = [[1, 3, 224, 224]]): diff --git a/ppq/quantization/optim/parameters.py b/ppq/quantization/optim/parameters.py index 1888124d..7727f05e 100644 --- a/ppq/quantization/optim/parameters.py +++ b/ppq/quantization/optim/parameters.py @@ -125,10 +125,10 @@ def optimize( state_records[config] = config.state config.state = QuantizationStates.DEQUANTIZED elif self._method is not None: - config.observer_algorithm = 'Minmax' + # config.observer_algorithm = 'Minmax' # pass # override quantizer's setting if necessary - # config.observer_algorithm = self._method + config.observer_algorithm = self._method observer = OperationObserver( operation=executor._graph.operations[op_name], diff --git a/ppq/quantization/quantizer/NCNNQuantizer.py b/ppq/quantization/quantizer/NCNNQuantizer.py index 95eb3e53..a6d97c98 100644 --- a/ppq/quantization/quantizer/NCNNQuantizer.py +++ b/ppq/quantization/quantizer/NCNNQuantizer.py @@ -39,7 +39,7 @@ def init_quantize_config( observer_algorithm='Minmax' ) - if operation.type in {'Conv', 'Gemm', 'Add', 'LayerNorm', 'MultiHeadAttention', 'Gelu', 'Concat'}: + if operation.type in {'Add', 'Concat', 'Conv', 'LayerNorm', 'MultiHeadAttention', 'Gemm', 'Gelu'}: assert operation.num_of_input > 0, 'Seems you got a Computing layer with no parameters.' if operation.type == 'Conv': @@ -119,7 +119,7 @@ def init_quantize_config( base_quant_config.input_quantization_config[0] = \ ChannelwiseTensorQuantizationConfig.convert_from_tensor_config( convert_from = inp_config, - offsets = None, scales = None, channel_axis = 0 + offsets = None, scales = None, channel_axis = 1 ) base_quant_config.input_quantization_config[0].observer_algorithm = 'Minmax' @@ -143,9 +143,8 @@ def default_platform(self) -> TargetPlatform: @ property def quant_operation_types(self) -> set: - # 'Conv', 'Gemm', 'Concat', 'Add', 'LayerNorm', 'MultiHeadAttention', 'Gelu' return { - 'Conv', 'Gemm', 'Concat', 'Add', 'LayerNorm', 'Gelu', + 'Add', 'Concat', 'Conv', 'LayerNorm', 'MultiHeadAttention', 'Gemm', 'Gelu' } @ property From b3a6369cef2f0e6745a0c2561cb04d9c85a5b077 Mon Sep 17 00:00:00 2001 From: tpoisonooo Date: Tue, 7 Jun 2022 10:35:35 +0800 Subject: [PATCH 05/22] fix(NCNNExporter): layernorm input scales --- ppq/IR/deploy.py | 12 +++---- ppq/core/data.py | 30 ++++++++-------- ppq/parser/ncnn_exporter.py | 38 ++++++++++++++------- ppq/parser/onnx_exporter.py | 2 +- ppq/parser/util.py | 6 ++-- ppq/quantization/quantizer/NCNNQuantizer.py | 25 ++++---------- 6 files changed, 57 insertions(+), 56 deletions(-) diff --git a/ppq/IR/deploy.py b/ppq/IR/deploy.py index 2316a292..7dbd8a05 100644 --- a/ppq/IR/deploy.py +++ b/ppq/IR/deploy.py @@ -68,7 +68,7 @@ def retrieve(self): for _, variable in self._graph.variables.items(): assert isinstance(variable, Variable), \ f'Failed to send graph to device, incorrect variable {variable} found.' - variable.value = convert_any_to_numpy(variable.value, accepet_none=True) + variable.value = convert_any_to_numpy(variable.value, accept_none=True) return self @@ -84,12 +84,12 @@ def deploy(self, device: str): if operator.type == 'Constant' and operator.platform != TargetPlatform.SHAPE_OR_INDEX: operator.attributes['value'] = \ convert_any_to_torch_tensor( - operator.attributes['value'], accepet_none=False).to(device) + operator.attributes['value'], accept_none=False).to(device) if operator.type == 'Constant' and operator.platform == TargetPlatform.SHAPE_OR_INDEX: value = operator.attributes['value'] operator.attributes['value'] = convert_any_to_torch_tensor( - value, accepet_none=False, device='cpu') + value, accept_none=False, device='cpu') for _, variable in self._graph.variables.items(): assert isinstance(variable, Variable), \ @@ -108,10 +108,10 @@ def deploy(self, device: str): # if all downstream operations are shape related operations, send value to cpu if platform == TargetPlatform.SHAPE_OR_INDEX: variable.value = convert_any_to_torch_tensor( - variable.value, accepet_none=True).to('cpu') + variable.value, accept_none=True).to('cpu') else: variable.value = convert_any_to_torch_tensor( - variable.value, accepet_none=True).to(device=device) + variable.value, accept_none=True).to(device=device) # if variable is a shape-related variable, send it to cpu. if variable.is_parameter: @@ -125,5 +125,5 @@ def deploy(self, device: str): if dest_op.type in {'Reshape', 'Slice', 'Gather', 'Pad', 'Resize', 'Split', 'TopK', 'Tile', 'Expand'}: if dest_idx >= 1 and len(variable.dest_ops) == 1: variable.value = convert_any_to_torch_tensor( - variable.value, accepet_none=True).to('cpu') + variable.value, accept_none=True).to('cpu') return self diff --git a/ppq/core/data.py b/ppq/core/data.py index ff4dee43..d171d366 100644 --- a/ppq/core/data.py +++ b/ppq/core/data.py @@ -212,20 +212,20 @@ def num_of_output(self): def convert_any_to_python_primary_type( x: Union[torch.Tensor, np.ndarray, int, float, list, str], - accepet_none: bool=True) -> Union[int, float, list, str]: - if x is None and accepet_none: return None - if x is None and not accepet_none: raise ValueError('Trying to convert an empty value.') + accept_none: bool=True) -> Union[int, float, list, str]: + if x is None and accept_none: return None + if x is None and not accept_none: raise ValueError('Trying to convert an empty value.') if isinstance(x, list) or isinstance(x, tuple): return list(x) elif isinstance(x, int) or isinstance(x, float): return x elif isinstance(x, torch.Tensor): - if x.numel() == 0 and accepet_none: return None - if x.numel() == 0 and not accepet_none: raise ValueError('Trying to convert an empty value.') + if x.numel() == 0 and accept_none: return None + if x.numel() == 0 and not accept_none: raise ValueError('Trying to convert an empty value.') if str(x.device) != 'cpu': x = x.cpu() if x.numel() == 1: return x.item() if x.numel() > 1: return x.tolist() elif isinstance(x, np.ndarray): - if x.size == 0 and accepet_none: return None - if x.size == 0 and not accepet_none: raise ValueError('Trying to convert an empty value.') + if x.size == 0 and accept_none: return None + if x.size == 0 and not accept_none: raise ValueError('Trying to convert an empty value.') if x.size == 1: return x.reshape((1, )).tolist()[0] if x.size > 1: return x.tolist() elif isinstance(x, str): @@ -236,14 +236,14 @@ def convert_any_to_python_primary_type( def convert_any_to_numpy( x: Union[torch.Tensor, np.ndarray, int, float, list, tuple], - accepet_none: bool=True) -> np.ndarray: - if x is None and accepet_none: return None - if x is None and not accepet_none: raise ValueError('Trying to convert an empty value.') + accept_none: bool=True) -> np.ndarray: + if x is None and accept_none: return None + if x is None and not accept_none: raise ValueError('Trying to convert an empty value.') if isinstance(x, np.ndarray): return x elif isinstance(x, int) or isinstance(x, float): return np.array([x, ]) elif isinstance(x, torch.Tensor): - if x.numel() == 0 and accepet_none: return None - if x.numel() == 0 and not accepet_none: raise ValueError('Trying to convert an empty value.') + if x.numel() == 0 and accept_none: return None + if x.numel() == 0 and not accept_none: raise ValueError('Trying to convert an empty value.') if x.numel() == 1: return convert_any_to_numpy(x.detach().cpu().item()) if x.numel() > 1: return x.detach().cpu().numpy() elif isinstance(x, list) or isinstance(x, tuple): @@ -254,9 +254,9 @@ def convert_any_to_numpy( def convert_any_to_torch_tensor( x: Union[torch.Tensor, np.ndarray, int, float, list, tuple], - accepet_none: bool=True, dtype: torch.dtype=None, device='cpu') -> torch.Tensor: - if x is None and accepet_none: return None - if x is None and not accepet_none: raise ValueError('Trying to convert an empty value.') + accept_none: bool=True, dtype: torch.dtype=None, device='cpu') -> torch.Tensor: + if x is None and accept_none: return None + if x is None and not accept_none: raise ValueError('Trying to convert an empty value.') if isinstance(x, list) or isinstance(x, tuple): if all([type(element) == int for element in x]): if dtype is None: dtype=torch.int64 diff --git a/ppq/parser/ncnn_exporter.py b/ppq/parser/ncnn_exporter.py index dfcbf073..33ca8648 100644 --- a/ppq/parser/ncnn_exporter.py +++ b/ppq/parser/ncnn_exporter.py @@ -7,6 +7,7 @@ from .caffe_exporter import CaffeExporter from .onnx_exporter import OnnxExporter from .util import convert_value +from collections import Iterable class NCNNExporter(GraphExporter): @@ -37,43 +38,54 @@ def export_quantization_config(self, config_path: str, graph: BaseGraph): fd.write('\n') elif op.type in {'Concat', 'Add', 'Gelu', 'LayerNorm'}: # Concat and Add has no weight, skip - pass + continue elif op.type == 'MultiHeadAttention': # TODO - pass + continue else: print('unknown quant type {} name {} during write weight scale'.format(op.type, op.name)) for op in topo_order: if hasattr(op, 'config'): # write input scale - if op.type in {'Conv', 'Gemm'}: - fd.write(f'{op.name} ') + print('op type {} name {}'.format(op.type, op.name)) + fd.write(f'{op.name} ') + if op.type in {'Conv', 'Gemm'}: input_cfg = op.config.input_quantization_config[0] - assert input_cfg.state == QuantizationStates.ACTIVATED and\ + assert input_cfg.state == QuantizationStates.ACTIVATED and \ input_cfg.policy.has_property(QuantizationProperty.PER_TENSOR) scale = convert_value(1 / input_cfg.scale, True, DataType.FP32) fd.write('%f '% scale) - fd.write('\n') - elif op.type in {'Concat', 'Add', 'Gelu', 'LayerNorm'}: - fd.write(f'{op.name} ') + elif op.type in {'Concat', 'Add', 'Gelu'}: for cfg in op.config.input_quantization_config: - assert cfg.state in {QuantizationStates.BAKED, QuantizationStates.ACTIVATED} \ + print('cfg state {} algo {}'.format(cfg.state, cfg.observer_algorithm)) + + assert cfg.state in {QuantizationStates.BAKED, QuantizationStates.ACTIVATED, QuantizationStates.SLAVE} \ and cfg.observer_algorithm in {'minmax', 'Minmax'} scale = convert_value(1.0 / cfg.scale, True, DataType.FP32) - if type(scale) == list: + if isinstance(scale, Iterable): for s in scale: fd.write('%f '% s) else: fd.write('%f '% scale) - - fd.write('\n') + elif op.type == 'LayerNorm': + cfg = op.config.input_quantization_config[0] + print('cfg state {} algo {}'.format(cfg.state, cfg.observer_algorithm)) + + assert cfg.state in {QuantizationStates.BAKED, QuantizationStates.ACTIVATED} \ + and cfg.observer_algorithm in {'minmax', 'Minmax'} + scale = convert_value(1.0 / cfg.scale, False, DataType.FP32) + if isinstance(scale, Iterable): + for s in scale: + fd.write('%f '% s) + else: + fd.write('%f '% scale) elif op.type == 'MultiHeadAttention': # TODO pass else: print('unknown quant type {} name {} during write input scale'.format(op.type, op.name)) - + fd.write('\n') fd.close() def export(self, file_path: str, graph: BaseGraph, config_path: str = None, input_shapes: List[List[int]] = [[1, 3, 224, 224]]): diff --git a/ppq/parser/onnx_exporter.py b/ppq/parser/onnx_exporter.py index ac2231fe..4ce2af5a 100644 --- a/ppq/parser/onnx_exporter.py +++ b/ppq/parser/onnx_exporter.py @@ -60,7 +60,7 @@ def export(self, operation: Operation, graph: BaseGraph, **kwargs) -> Operation: def convert_value(value: Union[int, float, np.ndarray, torch.Tensor]) -> str: if type(value) in {int, float}: return value else: - value = convert_any_to_numpy(value, accepet_none=True) + value = convert_any_to_numpy(value, accept_none=True) if value is None: return value # SOI config has Nona as its scale and return value.tolist() diff --git a/ppq/parser/util.py b/ppq/parser/util.py index c2bffb9f..232003a1 100644 --- a/ppq/parser/util.py +++ b/ppq/parser/util.py @@ -8,6 +8,7 @@ def convert_value( value: Union[int, float, np.ndarray, torch.Tensor], export_as_float: bool, dtype: DataType = DataType.FP32) -> Union[float, list]: + """Converting value from any to python native data dtype, ready for export. Args: @@ -18,10 +19,11 @@ def convert_value( Returns: Union[float, list]: Converted value """ + if dtype not in {DataType.FP32, DataType.INT32}: raise ValueError(f'Can Only export dtype fp32 and int32, ' f'while you are requiring to dump a {dtype.name} value') - value = convert_any_to_numpy(value, accepet_none=False) + value = convert_any_to_numpy(value, accept_none=False) value = value.astype(dtype=DataType.to_numpy(dtype)) if export_as_float: value = np.asscalar(value[0]) @@ -30,5 +32,5 @@ def convert_value( f'It is Expected to be a int or float value, while {type(value)} was given') return value else: - value = convert_any_to_numpy(value, accepet_none=False) + value = convert_any_to_numpy(value, accept_none=False) return value.tolist() diff --git a/ppq/quantization/quantizer/NCNNQuantizer.py b/ppq/quantization/quantizer/NCNNQuantizer.py index a6d97c98..fcf41f5d 100644 --- a/ppq/quantization/quantizer/NCNNQuantizer.py +++ b/ppq/quantization/quantizer/NCNNQuantizer.py @@ -81,24 +81,7 @@ def init_quantize_config( offsets = None, scales = None, channel_axis = 0 ) base_quant_config.input_quantization_config[1].observer_algorithm = 'Minmax' - - elif operation.type == 'Concat': - base_quant_config.input_quantization_config[1].policy = QuantizationPolicy( - QuantizationProperty.SYMMETRICAL + - QuantizationProperty.LINEAR + - QuantizationProperty.PER_TENSOR - ) - base_quant_config.input_quantization_config[1].observer_algorithm = 'Minmax' - - elif operation.type == 'Add': - # Add 量化输入 - base_quant_config.input_quantization_config[1].policy = QuantizationPolicy( - QuantizationProperty.SYMMETRICAL + - QuantizationProperty.LINEAR + - QuantizationProperty.PER_TENSOR - ) - base_quant_config.input_quantization_config[1].observer_algorithm = 'Minmax' - + elif operation.type == 'LayerNorm': # LayerNorm 输入按 power of 2 量化 @@ -128,7 +111,11 @@ def init_quantize_config( bconfig = base_quant_config.input_quantization_config[2] wconfig.state = QuantizationStates.FP32 bconfig.state = QuantizationStates.FP32 - + + elif operation.type in {'Add', 'Concat'}: + # use default param + pass + # 显式说明输出不量化 base_quant_config.output_quantization_config[0].state = QuantizationStates.FP32 return base_quant_config From 216d2910759a185a0e29224780e51151b9400635 Mon Sep 17 00:00:00 2001 From: tpoisonooo Date: Wed, 8 Jun 2022 22:02:12 +0800 Subject: [PATCH 06/22] feat(model): add PER_CHANNEL_BNC --- doc/pages/instructions/ppq_quant_1.html | 3 +- ppq/core/quant.py | 23 +++++++----- ppq/csrc/cuda/PPQ.h | 7 ++-- ppq/quantization/algorithm/training.py | 20 +++++++++-- ppq/quantization/observer/range.py | 39 ++++++++++++++------- ppq/quantization/optim/calibration.py | 5 +-- ppq/quantization/optim/training.py | 2 ++ ppq/quantization/qfunction/linear.py | 6 ++++ ppq/quantization/quantizer/NCNNQuantizer.py | 12 ++----- 9 files changed, 80 insertions(+), 37 deletions(-) diff --git a/doc/pages/instructions/ppq_quant_1.html b/doc/pages/instructions/ppq_quant_1.html index f674d985..486f8126 100644 --- a/doc/pages/instructions/ppq_quant_1.html +++ b/doc/pages/instructions/ppq_quant_1.html @@ -183,7 +183,8 @@

QuantizationPolicy

QuantizationPolicy 在 PPQ 中用来描述量化策略,它是一些 QuantizationProperty 枚举的组合位图。在 PPQ 中我们支持的 QuantizationProperty 包括:

  • PER_TENSOR:逐层量化。
  • -
  • PER_CHANNEL:逐通道量化。
  • +
  • PER_CHANNEL:CNN 模型逐通道量化。
  • +
  • PER_CHANNEL_BNC:Transformer 模型逐通道量化。
  • LINEAR: 线性量化。
  • EXPONENTIAL: 指数量化。
  • SYMMETRICAL: 对称量化。
  • diff --git a/ppq/core/quant.py b/ppq/core/quant.py index d604c140..56f555c5 100644 --- a/ppq/core/quant.py +++ b/ppq/core/quant.py @@ -140,7 +140,9 @@ class QuantizationProperty(Enum): PER_TENSOR: Also known as per-layer quantization, which mean all parameters of this layer share the same scale and offset. (For Convulution layer and Gemm layer which has bias, bias layer will be negative quantized, they do not have a valid scale) - PER_CHANNEL: parameters are quantized alone channel axis, each channel has a stand-alone scale and offset. + PER_CHANNEL: CNN model parameters are quantized along channel axis, each channel has a stand-alone scale and offset. + + PER_CHANNEL_BNC: transformer model parameters are quantized per-chanel, each channel has a stand-alone scale and offset. LINEAR: Indicates a linear quantization, follow formula: quant(x) = clip(round(x / scale)) @@ -151,18 +153,22 @@ class QuantizationProperty(Enum): ASYMMETRICAL: Indicates an asymmetrical quantization, offset is activated in this mode. POWER_OF_2: Indicates a power-of-2 quantization, scale must be pow(2, k) in this mode. + + PTF_BNC: Indicates a power-of-2 for quantization for layernorm input, scale must be pow(2, k) in this mode. ATTENTION: Not all combinations of all 7 QuantizationProperty are valid, see QuantizationPolicy.__check_valid ATTENTION: QuantizationPolicy is read-only, user can only assign its value when created, the only interface of QuantizationPolicy is function QuantizationPolicy.has_property. """ - PER_TENSOR = 0x00000001 - PER_CHANNEL = 0x00000002 - LINEAR = 0x00000004 - EXPONENTIAL = 0x00000008 - SYMMETRICAL = 0x00000010 - ASYMMETRICAL = 0x00000020 - POWER_OF_2 = 0x00000040 + PER_TENSOR = 0x00000001 + PER_CHANNEL = 0x00000002 + PER_CHANNEL_BNC = 0x00000080 + LINEAR = 0x00000004 + EXPONENTIAL = 0x00000008 + SYMMETRICAL = 0x00000010 + ASYMMETRICAL = 0x00000020 + POWER_OF_2 = 0x00000040 + PTF_BNC = 0x00000100 def __or__(self, other: int) -> int: return self.value + other @@ -247,6 +253,7 @@ def __check_valid(cls, policy): QuantizationProperty.SYMMETRICAL | QuantizationProperty.LINEAR | QuantizationProperty.PER_TENSOR | QuantizationProperty.POWER_OF_2, QuantizationProperty.ASYMMETRICAL | QuantizationProperty.LINEAR | QuantizationProperty.PER_CHANNEL | QuantizationProperty.POWER_OF_2, QuantizationProperty.SYMMETRICAL | QuantizationProperty.LINEAR | QuantizationProperty.PER_CHANNEL | QuantizationProperty.POWER_OF_2, + QuantizationProperty.SYMMETRICAL | QuantizationProperty.LINEAR | QuantizationProperty.PER_CHANNEL_BNC | QuantizationProperty.PTF_BNC, } def to_dict(self) -> dict: diff --git a/ppq/csrc/cuda/PPQ.h b/ppq/csrc/cuda/PPQ.h index d5138fb6..d1d5fed8 100644 --- a/ppq/csrc/cuda/PPQ.h +++ b/ppq/csrc/cuda/PPQ.h @@ -112,13 +112,16 @@ ATTENTION: QuantizationPolicy is read-only, user can only assign its value when created, the only interface of QuantizationPolicy is function QuantizationPolicy.has_property. */ - # define PPQ_QPROPERTY_PER_TENSOR 0x00000001 - # define PPQ_QPROPERTY_PER_CHANNEL 0x00000002 + # define PPQ_QPROPERTY_PER_TENSOR 0x00000001 + # define PPQ_QPROPERTY_PER_CHANNEL 0x00000002 + # define PPQ_QPROPERTY_PER_CHANNEL_BNC 0x00000080 + # define PPQ_QPROPERTY_LINEAR 0x00000004 # define PPQ_QPROPERTY_EXPONENTIAL 0x00000008 # define PPQ_QPROPERTY_SYMMETRICAL 0x00000010 # define PPQ_QPROPERTY_ASYMMETRICAL 0x00000020 # define PPQ_QPROPERTY_POWER_OF_2 0x00000040 + # define PPQ_QPROPERTY_PTF_BNC 0x00000100 # define PPQ_CPP_EXTENSION # endif diff --git a/ppq/quantization/algorithm/training.py b/ppq/quantization/algorithm/training.py index 7b5f2d0d..9712111c 100644 --- a/ppq/quantization/algorithm/training.py +++ b/ppq/quantization/algorithm/training.py @@ -221,9 +221,11 @@ def PPQTensorClip( limit: torch.Tensor, config: TensorQuantizationConfig) -> torch.Tensor: if config.policy.has_property(QuantizationProperty.PER_CHANNEL): assert isinstance(config, ChannelwiseTensorQuantizationConfig) - return Clip_C.apply(tensor, reference, limit, config.channel_axis) + return Clip_C.apply(tensor, reference, limit, config.channel_axis) elif config.policy.has_property(QuantizationProperty.PER_TENSOR): return Clip_T.apply(tensor, reference, limit) + elif config.policy.has_property(QuantizationProperty.PER_CHANNEL_BNC): + raise Exception('Not implement PER_CHANNEL_BNC in PPQTensorClip') else: raise Exception('Oops, seems we got some problems here.') @@ -240,6 +242,8 @@ def PPQRoundingLoss(tensor: torch.Tensor, return RoundingLoss_T.apply( tensor, config.scale, config.offset, config.quant_min, config.quant_max, config.rounding) + elif config.policy.has_property(QuantizationProperty.PER_CHANNEL_BNC): + raise Exception('Not implement PER_CHANNEL_BNC in PPQRoundingLoss') else: raise Exception('Oops, seems we got some problems here.') @@ -743,7 +747,10 @@ def __call__(self, tensor: torch.Tensor, config: TensorQuantizationConfig) -> to shape = [1 if axis != config.channel_axis else -1 for axis in range(tensor.ndim)] scale = scale.view(shape) bias = bias.view(shape) - + + if config.policy.has_property(QuantizationProperty.PER_CHANNEL_BNC): + raise Exception('Not implement PER_CHANNEL_BNC in StraightThroughEstimateDelegator') + # only bias doesn't need offset in asym quant if not self.passive and config.policy.has_property(QuantizationProperty.ASYMMETRICAL): tensor = tensor + bias.abs() @@ -782,6 +789,9 @@ def initiate_rounding(self) -> Union[None, torch.nn.Parameter]: assert isinstance(self.config, ChannelwiseTensorQuantizationConfig) shape = [1 if axis != self.config.channel_axis else -1 for axis in range(weight.ndim)] scale = scale.view(shape) + elif self.config.policy.has_property(QuantizationProperty.PER_CHANNEL_BNC): + raise Exception('Not implement PER_CHANNEL_BNC in StraightThroughEstimateDelegator') + round_diff = (weight / scale) - (weight / scale).floor() v_init = -torch.log((self.reg.zeta - self.reg.gamma) / (round_diff - self.reg.gamma) - 1) continuous_v = torch.nn.Parameter(v_init, True) @@ -813,6 +823,9 @@ def finalize(self) -> None: shape = [1 if axis != self.config.channel_axis else -1 for axis in range(weight.ndim)] scale = scale.view(shape) offset = offset.view(shape) + elif self.policy.has_property(QuantizationProperty.PER_CHANNEL_BNC): + raise Exception('Not implement PER_CHANNEL_BNC in BlockwiseReconstructionDelegator finalize') + weight = (weight / scale).floor() + (self.rounding >= 0).float() weight = torch.clamp(weight + offset, self.config.quant_min, self.config.quant_max) weight = (weight - offset) * scale @@ -829,6 +842,9 @@ def __call__(self, tensor: torch.Tensor, config: TensorQuantizationConfig) -> to shape = [1 if axis != config.channel_axis else -1 for axis in range(tensor.ndim)] scale = scale.view(shape) offset = offset.view(shape) + elif config.policy.has_property(QuantizationProperty.PER_CHANNEL_BNC): + raise Exception('Not implement PER_CHANNEL_BNC in BlockwiseReconstructionDelegator __call__') + tensor = (tensor / scale).floor() + self.reg.rectified_sigmoid(self.rounding) tensor = torch.clamp(tensor + offset, config.quant_min, config.quant_max) tensor = (tensor - offset) * scale diff --git a/ppq/quantization/observer/range.py b/ppq/quantization/observer/range.py index 42f00a43..7a54b24b 100644 --- a/ppq/quantization/observer/range.py +++ b/ppq/quantization/observer/range.py @@ -66,6 +66,14 @@ def observe(self, value: torch.Tensor): channelwise_view = torch.flatten(channelwise_view, start_dim=1) self._min_val_collector.append(torch.min(channelwise_view, dim=1, keepdim=True)[0]) self._max_val_collector.append(torch.max(channelwise_view, dim=1, keepdim=True)[0]) + elif self._quant_cfg.policy.has_property(QuantizationProperty.PER_CHANNEL_BNC): + import pdb + pdb.set_trace() + assert len(value.shape) == 3 + channelwise_view = value.reshape(-1, value.shape[-1]) + channelwise_view = channelwise_view.transpose(0,1) + self._min_val_collector.append(torch.min(channelwise_view, dim=1, keepdim=True)[0]) + self._max_val_collector.append(torch.max(channelwise_view, dim=1, keepdim=True)[0]) else: raise TypeError('Min-max Observer only work with per-tensor or per-channel quantize policy.') @@ -84,7 +92,9 @@ def render_quantization_config(self): self._quant_cfg.offset = torch.tensor([offset], dtype=torch.float32, device=device).squeeze(0) self._quant_cfg.state = QuantizationStates.ACTIVATED - elif self._quant_cfg.policy.has_property(QuantizationProperty.PER_CHANNEL): + elif self._quant_cfg.policy.has_property(QuantizationProperty.PER_CHANNEL) \ + or self._quant_cfg.policy.has_property(QuantizationProperty.PER_CHANNEL_BNC): + min_vals = torch.min(torch.cat(self._min_val_collector, dim=-1), dim=-1, keepdim=False)[0].cpu().numpy() max_vals = torch.max(torch.cat(self._max_val_collector, dim=-1), dim=-1, keepdim=False)[0].cpu().numpy() assert(len(min_vals) == len(max_vals)), 'Min values and max values should at same length.' @@ -260,19 +270,20 @@ def observe(self, value: torch.Tensor): self._percentile_collector.append(torch.cat([_max, _min], dim=-1)) else: self._percentile_collector.append(CUDA.Quantile(value, self._percentile).view(1, -1)) - elif self._quant_cfg.policy.has_property(QuantizationProperty.PER_CHANNEL): + elif self._quant_cfg.policy.has_property(QuantizationProperty.PER_CHANNEL) or \ + self._quant_cfg.policy.has_property(QuantizationProperty.PER_CHANNEL_BNC): import pdb pdb.set_trace() raise PermissionError('Percentile observer can not deal with per channel quantization.') - - assert isinstance(self._quant_cfg, ChannelwiseTensorQuantizationConfig), ( - 'Your quantization config has PER_CHANNEL while it is not a ' - 'ChannelwiseTensorQuantizationConfig instance.') - channel_axis = self._quant_cfg.channel_axis - channelwise_view = value.transpose(dim0=0, dim1=channel_axis) - channelwise_view = torch.flatten(channelwise_view, start_dim=1) - self._percentile_mins.append(-torch.quantile(-channelwise_view, q=self._percentile, dim=1, keepdim=True)[0]) - self._percentile_maxs.append(torch.quantile(channelwise_view, q=self._percentile, dim=1, keepdim=True)[0]) + + # assert isinstance(self._quant_cfg, ChannelwiseTensorQuantizationConfig), ( + # 'Your quantization config has PER_CHAN=NEL while it is not a ' + # 'ChannelwiseTensorQuantizationConfig instance.') + # channel_axis = self._quant_cfg.channel_axis + # channelwise_view = value.transpose(dim0=0, dim1=channel_axis) + # channelwise_view = torch.flatten(channelwise_view, start_dim=1) + # self._percentile_mins.append(-torch.quantile(-channelwise_view, q=self._percentile, dim=1, keepdim=True)[0]) + # self._percentile_maxs.append(torch.quantile(channelwise_view, q=self._percentile, dim=1, keepdimTrue)[0]) else: raise TypeError('Min-max Observer only work with per-tensor or per-channel quantize policy.') @@ -291,7 +302,8 @@ def render_quantization_config(self): self._quant_cfg.scale = torch.tensor([scale], dtype=torch.float32, device=device).squeeze(0) self._quant_cfg.offset = torch.tensor([offset], dtype=torch.float32, device=device).squeeze(0) self._quant_cfg.state = QuantizationStates.ACTIVATED - elif self._quant_cfg.policy.has_property(QuantizationProperty.PER_CHANNEL): + elif self._quant_cfg.policy.has_property(QuantizationProperty.PER_CHANNEL) or \ + self._quant_cfg.policy.has_property(QuantizationProperty.PER_CHANNEL_BNC): import pdb pdb.set_trace() raise PermissionError('Percentile observer can not deal with per channel quantization.') @@ -353,6 +365,9 @@ def hist_to_scale_offset( if config.policy.has_property(QuantizationProperty.PER_CHANNEL): raise PermissionError('Torch Mse observer do not support PER_CHANNEL policy now, please wait.') + if config.policy.has_property(QuantizationProperty.PER_CHANNEL_BNC): + raise PermissionError('Torch Mse observer do not support PER_CHANNEL_BNC policy now, please wait.') + if (config.policy.has_property(QuantizationProperty.SYMMETRICAL) and config.policy.has_property(QuantizationProperty.PER_TENSOR)): scale = 1 # hist scale diff --git a/ppq/quantization/optim/calibration.py b/ppq/quantization/optim/calibration.py index 87a9f443..94acd4b0 100644 --- a/ppq/quantization/optim/calibration.py +++ b/ppq/quantization/optim/calibration.py @@ -311,11 +311,12 @@ def optimize(self, processor: GraphCommandProcessor, dataloader: Iterable, assert isinstance(cfg, TensorQuantizationConfig) assert isinstance(observer, TorchMinMaxObserver) - if observer._quant_cfg.policy.has_property(QuantizationProperty.PER_CHANNEL): + if observer._quant_cfg.policy.has_property(QuantizationProperty.PER_CHANNEL) or\ + observer._quant_cfg.policy.has_property(QuantizationProperty.PER_CHANNEL_BNC): min_vals = torch.min(torch.cat(observer._min_val_collector, dim=-1), dim=-1, keepdim=False)[0].cpu().numpy() max_vals = torch.max(torch.cat(observer._max_val_collector, dim=-1), dim=-1, keepdim=False)[0].cpu().numpy() cfg.detail.update({'range_min': min_vals, 'range_max': max_vals}) - + elif observer._quant_cfg.policy.has_property(QuantizationProperty.PER_TENSOR): min_val = torch.min(torch.cat(observer._min_val_collector, dim=0)).cpu().item(), max_val = torch.max(torch.cat(observer._max_val_collector, dim=0)).cpu().item(), diff --git a/ppq/quantization/optim/training.py b/ppq/quantization/optim/training.py index 81a46912..dd28f0ee 100644 --- a/ppq/quantization/optim/training.py +++ b/ppq/quantization/optim/training.py @@ -490,6 +490,8 @@ def optimize( for axis in range(fp_weight.ndim)] weight_scale = weight_scale.view(view_shape) weight_offset = weight_offset.view(view_shape) + elif weight_quantization_config.policy.has_property(QuantizationProperty.PER_CHANNEL_BNC): + raise Exception('Not implement PER_CHANNEL_BNC in AdaRoundPass optimize') # init continuous_v, make sure h(v) = round_diff round_diff = (fp_weight / weight_scale) - (fp_weight / weight_scale).floor() diff --git a/ppq/quantization/qfunction/linear.py b/ppq/quantization/qfunction/linear.py index 95310b20..7ae596a2 100644 --- a/ppq/quantization/qfunction/linear.py +++ b/ppq/quantization/qfunction/linear.py @@ -200,6 +200,12 @@ def PPQLinearQuant_toInt(tensor: torch.Tensor, config: TensorQuantizationConfig, elif config.policy.has_property(QuantizationProperty.PER_TENSOR): tensor = ppq_tensor_round((tensor / config.scale), config.rounding) + config.offset tensor = torch.clamp(tensor, config.quant_min, config.quant_max) + elif config.policy.has_property(QuantizationProperty.PER_CHANNEL_BNC): + assert len(tensor.shape()) == 3 + scale = config.scale.reshape((1,1,-1)) + offset = config.offset.reshape((1,1,-1)) + tensor = ppq_tensor_round((tensor / scale), config.rounding) + offset + tensor = torch.clamp(tensor, config.quant_min, config.quant_max) return tensor.type(dtype=torch.int32) diff --git a/ppq/quantization/quantizer/NCNNQuantizer.py b/ppq/quantization/quantizer/NCNNQuantizer.py index fcf41f5d..ae959b3e 100644 --- a/ppq/quantization/quantizer/NCNNQuantizer.py +++ b/ppq/quantization/quantizer/NCNNQuantizer.py @@ -84,20 +84,12 @@ def init_quantize_config( elif operation.type == 'LayerNorm': # LayerNorm 输入按 power of 2 量化 - - # base_quant_config.input_quantization_config[0].policy = QuantizationPolicy( - # QuantizationProperty.SYMMETRICAL + - # QuantizationProperty.LINEAR + - # QuantizationProperty.PER_TENSOR - # ) - # base_quant_config.input_quantization_config[0].observer_algorithm = 'Minmax' - inp_config = base_quant_config.input_quantization_config[0] inp_config.policy = QuantizationPolicy( QuantizationProperty.SYMMETRICAL + QuantizationProperty.LINEAR + - QuantizationProperty.PER_CHANNEL + - QuantizationProperty.POWER_OF_2 + QuantizationProperty.PER_CHANNEL_BNC + + QuantizationProperty.PTF_BNC ) base_quant_config.input_quantization_config[0] = \ ChannelwiseTensorQuantizationConfig.convert_from_tensor_config( From 2af789aef3f8e6c4c3dad7e9361caa3efa13f36b Mon Sep 17 00:00:00 2001 From: tpoisonooo Date: Thu, 9 Jun 2022 09:35:37 +0800 Subject: [PATCH 07/22] feat(range.py): add PTF BNC to get scale and offset --- ppq/quantization/observer/range.py | 64 ++++++++++++++++++++++++++++-- ppq/quantization/observer/utils.py | 10 +++++ 2 files changed, 71 insertions(+), 3 deletions(-) create mode 100644 ppq/quantization/observer/utils.py diff --git a/ppq/quantization/observer/range.py b/ppq/quantization/observer/range.py index 7a54b24b..7bd4699a 100644 --- a/ppq/quantization/observer/range.py +++ b/ppq/quantization/observer/range.py @@ -14,8 +14,53 @@ from ppq.utils.round import ppq_numerical_round, ppq_round_to_power_of_2 from .base import BaseTensorObserver +from .utils import lp_loss +@ ppq_quant_param_computing_function +def PTF_BNC_to_scale_offset( + min_val: list, max_val: list, + inputs: torch.Tensor, + config: TensorQuantizationConfig +) -> Tuple[torch.Tensor, torch.Tensor]: + max_val = torch.Tensor(max_val) + min_val = torch.Tensor(min_val) + + import pdb + pdb.set_trace() + + qmax = config.quant_max + qmin = config.quant_min + + max_val_t = max_val.max() + min_val_t = min_val.min() + scale8 = (max_val_t - min_val_t) / float(qmax - qmin) + scale8.clamp_(1e-7) + scale4 = scale8 / 2 + scale2 = scale4 / 2 + scale1 = scale2 / 2 + zero_point = qmin - torch.round(min_val_t / scale8) + zero_point.clamp_(qmin, qmax) + scale_mask = torch.ones_like(max_val) + for j in range(inputs.shape[2]): + data = inputs[..., j].unsqueeze(-1) + data_q1 = ((data / scale1 + zero_point).round().clamp(qmin, qmax) - + zero_point) * scale1 + data_q2 = ((data / scale2 + zero_point).round().clamp(qmin, qmax) - + zero_point) * scale2 + data_q4 = ((data / scale4 + zero_point).round().clamp(qmin, qmax) - + zero_point) * scale4 + data_q8 = ((data / scale8 + zero_point).round().clamp(qmin, qmax) - + zero_point) * scale8 + score1 = lp_loss(data, data_q1, p=2.0, reduction='all') + score2 = lp_loss(data, data_q2, p=2.0, reduction='all') + score4 = lp_loss(data, data_q4, p=2.0, reduction='all') + score8 = lp_loss(data, data_q8, p=2.0, reduction='all') + score = [score1, score2, score4, score8] + scale_mask[j] *= 2**score.index(min(score)) + scale = scale1 * scale_mask + return scale, zero_point + @ ppq_quant_param_computing_function def minmax_to_scale_offset( min_val: float, max_val: float, @@ -48,6 +93,7 @@ def __init__(self, watch_on: Variable, quant_cfg: TensorQuantizationConfig): super().__init__(watch_on, quant_cfg) self._min_val_collector = [] self._max_val_collector = [] + self._last_input = None @ torch.no_grad() def observe(self, value: torch.Tensor): @@ -74,6 +120,7 @@ def observe(self, value: torch.Tensor): channelwise_view = channelwise_view.transpose(0,1) self._min_val_collector.append(torch.min(channelwise_view, dim=1, keepdim=True)[0]) self._max_val_collector.append(torch.max(channelwise_view, dim=1, keepdim=True)[0]) + self._last_input = value.cpu().clone() else: raise TypeError('Min-max Observer only work with per-tensor or per-channel quantize policy.') @@ -92,9 +139,7 @@ def render_quantization_config(self): self._quant_cfg.offset = torch.tensor([offset], dtype=torch.float32, device=device).squeeze(0) self._quant_cfg.state = QuantizationStates.ACTIVATED - elif self._quant_cfg.policy.has_property(QuantizationProperty.PER_CHANNEL) \ - or self._quant_cfg.policy.has_property(QuantizationProperty.PER_CHANNEL_BNC): - + elif self._quant_cfg.policy.has_property(QuantizationProperty.PER_CHANNEL): min_vals = torch.min(torch.cat(self._min_val_collector, dim=-1), dim=-1, keepdim=False)[0].cpu().numpy() max_vals = torch.max(torch.cat(self._max_val_collector, dim=-1), dim=-1, keepdim=False)[0].cpu().numpy() assert(len(min_vals) == len(max_vals)), 'Min values and max values should at same length.' @@ -105,6 +150,19 @@ def render_quantization_config(self): scales.append(scale) offsets.append(offset) + # scale, offset here only deployed on cpu + # we will move them towards target device through RunnableGraph + self._quant_cfg.scale = torch.tensor(scales, dtype=torch.float32, device=device) + self._quant_cfg.offset = torch.tensor(offsets, dtype=torch.float32, device=device) + self._quant_cfg.state = QuantizationStates.ACTIVATED + elif self._quant_cfg.policy.has_property(QuantizationProperty.PER_CHANNEL_BNC): + + min_vals = torch.min(torch.cat(self._min_val_collector, dim=-1), dim=-1, keepdim=False)[0].cpu().numpy() + max_vals = torch.max(torch.cat(self._max_val_collector, dim=-1), dim=-1, keepdim=False)[0].cpu().numpy() + assert(len(min_vals) == len(max_vals)), 'Min values and max values should at same length.' + + scales, offsets = PTF_BNC_to_scale_offset(min_val=min_vals, max_val=max_vals, inputs=self._last_input, config=self._quant_cfg) + # scale, offset here only deployed on cpu # we will move them towards target device through RunnableGraph self._quant_cfg.scale = torch.tensor(scales, dtype=torch.float32, device=device) diff --git a/ppq/quantization/observer/utils.py b/ppq/quantization/observer/utils.py new file mode 100644 index 00000000..15d6032b --- /dev/null +++ b/ppq/quantization/observer/utils.py @@ -0,0 +1,10 @@ +# Copyright (c) MEGVII Inc. and its affiliates. All Rights Reserved. +# https://github.com/megvii-research/FQ-ViT +def lp_loss(pred, tgt, p=2.0, reduction='none'): + """ + loss function measured in L_p Norm + """ + if reduction == 'none': + return (pred - tgt).abs().pow(p).sum(1).mean() + else: + return (pred - tgt).abs().pow(p).mean() From 7975f397abe7f6825db859d2df64dd3cfa1701e6 Mon Sep 17 00:00:00 2001 From: tpoisonooo Date: Fri, 10 Jun 2022 16:58:07 +0800 Subject: [PATCH 08/22] feat(ncnn_exporter): support toml format --- ppq/parser/ncnn_exporter.py | 130 ++++++++++++-------- ppq/quantization/observer/range.py | 9 +- ppq/quantization/optim/parameters.py | 2 - ppq/quantization/quantizer/NCNNQuantizer.py | 6 +- ppq/quantization/quantizer/__init__.py | 2 +- requirements.txt | 1 + 6 files changed, 89 insertions(+), 61 deletions(-) diff --git a/ppq/parser/ncnn_exporter.py b/ppq/parser/ncnn_exporter.py index 33ca8648..ce81053f 100644 --- a/ppq/parser/ncnn_exporter.py +++ b/ppq/parser/ncnn_exporter.py @@ -1,3 +1,4 @@ +from lib2to3.pytree import convert from typing import List from ppq.core import (DataType, NetworkFramework, QuantizationProperty, @@ -11,16 +12,55 @@ class NCNNExporter(GraphExporter): - def export_quantization_config(self, config_path: str, graph: BaseGraph): + + def export_raw_quant_config(self, config_path: str, graph: BaseGraph): + ''' ncnn table format when version <= 20220420 ''' fd = open(config_path, 'w+') topo_order = graph.topological_sort() - import pdb - pdb.set_trace() - # write weight scale for op in topo_order: + if op.is_computing_op and isinstance(op, QuantableOperation): + fd.write(f'{op.name}_param_0 ') + param_cfg = op.config.input_quantization_config[1] + assert param_cfg.state in {QuantizationStates.BAKED, QuantizationStates.ACTIVATED}\ + and param_cfg.observer_algorithm in {'minmax', 'Minmax'} and \ + param_cfg.policy.has_property(QuantizationProperty.PER_CHANNEL) + # a workaround for depthwise conv in ncnn + # will cause mis-alignment between ppq and ncnn + if op.type == 'Conv' and op.attributes.get('group', 1) > 1: + group = op.attributes.get('group', 1) + scale = param_cfg.scale.reshape(group, -1).max(dim=1)[0] + else: + scale = param_cfg.scale + scale = convert_value(1 / scale, False, DataType.FP32) + for s in scale: + fd.write('%f '% s) + fd.write('\n') + for op in topo_order: + if op.is_computing_op and isinstance(op, QuantableOperation): + fd.write(f'{op.name} ') + input_cfg = op.config.input_quantization_config[0] + assert input_cfg.state == QuantizationStates.ACTIVATED and\ + input_cfg.policy.has_property(QuantizationProperty.PER_TENSOR) + scale = convert_value(1 / input_cfg.scale, True, DataType.FP32) + fd.write('%f '% scale) + fd.write('\n') + fd.close() + + + def export_toml_quant_config(self, config_path: str, graph: BaseGraph): + ''' toml is human readable format ''' + import toml + from ppq.core.config import PPQ_CONFIG + + table = {'source': '{} {}'.format(PPQ_CONFIG.NAME, PPQ_CONFIG.VERSION)} + order = graph.topological_sort() + + for op in order: if hasattr(op, 'config'): - if op.type in {'Conv', 'Gemm'}: - fd.write(f'{op.name}_param_0 ') + item = dict() + # avoiding Gather to Crop, we cannot deduce opr_type from opr_name + item['type'] = op.type + if op.type in {'Conv', 'Gemm'}: param_cfg = op.config.input_quantization_config[1] assert param_cfg.state in {QuantizationStates.BAKED, QuantizationStates.ACTIVATED}\ and param_cfg.observer_algorithm in {'minmax', 'Minmax'} and \ @@ -32,61 +72,55 @@ def export_quantization_config(self, config_path: str, graph: BaseGraph): scale = param_cfg.scale.reshape(group, -1).max(dim=1)[0] else: scale = param_cfg.scale - scale = convert_value(1 / scale, False, DataType.FP32) - for s in scale: - fd.write('%f '% s) - fd.write('\n') - elif op.type in {'Concat', 'Add', 'Gelu', 'LayerNorm'}: - # Concat and Add has no weight, skip - continue - elif op.type == 'MultiHeadAttention': - # TODO - continue - else: - print('unknown quant type {} name {} during write weight scale'.format(op.type, op.name)) - for op in topo_order: - if hasattr(op, 'config'): - # write input scale - print('op type {} name {}'.format(op.type, op.name)) - fd.write(f'{op.name} ') - if op.type in {'Conv', 'Gemm'}: + item['weight'] = convert_value(1 / scale, False, DataType.FP32) + input_cfg = op.config.input_quantization_config[0] assert input_cfg.state == QuantizationStates.ACTIVATED and \ input_cfg.policy.has_property(QuantizationProperty.PER_TENSOR) - scale = convert_value(1 / input_cfg.scale, True, DataType.FP32) - fd.write('%f '% scale) - elif op.type in {'Concat', 'Add', 'Gelu'}: + item['input_scale'] = convert_value(1 / input_cfg.scale, True, DataType.FP32) + + elif op.type in {'Add'}: + # Add may have multiple input node + input_scale = [] + zero_point = [] + for cfg in op.config.input_quantization_config: - print('cfg state {} algo {}'.format(cfg.state, cfg.observer_algorithm)) - assert cfg.state in {QuantizationStates.BAKED, QuantizationStates.ACTIVATED, QuantizationStates.SLAVE} \ and cfg.observer_algorithm in {'minmax', 'Minmax'} - scale = convert_value(1.0 / cfg.scale, True, DataType.FP32) - - if isinstance(scale, Iterable): - for s in scale: - fd.write('%f '% s) - else: - fd.write('%f '% scale) - elif op.type == 'LayerNorm': + input_scale.append(convert_value(1.0 / cfg.scale, True, DataType.FP32)) + zero_point.extend(convert_value(cfg.offset, False, DataType.INT32)) + + item['input_scale'] = input_scale + item['zero_point'] = zero_point + + elif op.type in {'LayerNorm', 'Gelu'}: cfg = op.config.input_quantization_config[0] print('cfg state {} algo {}'.format(cfg.state, cfg.observer_algorithm)) - + assert cfg.state in {QuantizationStates.BAKED, QuantizationStates.ACTIVATED} \ and cfg.observer_algorithm in {'minmax', 'Minmax'} - scale = convert_value(1.0 / cfg.scale, False, DataType.FP32) - if isinstance(scale, Iterable): - for s in scale: - fd.write('%f '% s) - else: - fd.write('%f '% scale) + item['input_scale'] = convert_value(1.0 / cfg.scale, False, DataType.FP32) + item['zero_point'] = convert_value(cfg.offset, False, DataType.INT32) + elif op.type == 'MultiHeadAttention': + # TODO - pass + continue else: - print('unknown quant type {} name {} during write input scale'.format(op.type, op.name)) - fd.write('\n') - fd.close() + print('unknown quant type {} name {} during write weight scale'.format(op.type, op.name)) + + table[op.name] = item + + toml.dump(table, open(config_path, 'w+')) + + + def export_quantization_config(self, config_path: str, graph: BaseGraph): + toml_style = True + if toml_style: + self.export_toml_quant_config(config_path=config_path, graph=graph) + else: + self.export_raw_quant_config(config_path=config_path, graph=graph) + def export(self, file_path: str, graph: BaseGraph, config_path: str = None, input_shapes: List[List[int]] = [[1, 3, 224, 224]]): if config_path is not None: diff --git a/ppq/quantization/observer/range.py b/ppq/quantization/observer/range.py index 7bd4699a..e799b085 100644 --- a/ppq/quantization/observer/range.py +++ b/ppq/quantization/observer/range.py @@ -16,7 +16,7 @@ from .base import BaseTensorObserver from .utils import lp_loss - +# https://github.com/megvii-research/FQ-ViT/blob/main/models/ptq/observer/ptf.py#L31 @ ppq_quant_param_computing_function def PTF_BNC_to_scale_offset( min_val: list, max_val: list, @@ -26,9 +26,6 @@ def PTF_BNC_to_scale_offset( max_val = torch.Tensor(max_val) min_val = torch.Tensor(min_val) - import pdb - pdb.set_trace() - qmax = config.quant_max qmin = config.quant_min @@ -113,14 +110,12 @@ def observe(self, value: torch.Tensor): self._min_val_collector.append(torch.min(channelwise_view, dim=1, keepdim=True)[0]) self._max_val_collector.append(torch.max(channelwise_view, dim=1, keepdim=True)[0]) elif self._quant_cfg.policy.has_property(QuantizationProperty.PER_CHANNEL_BNC): - import pdb - pdb.set_trace() assert len(value.shape) == 3 channelwise_view = value.reshape(-1, value.shape[-1]) channelwise_view = channelwise_view.transpose(0,1) self._min_val_collector.append(torch.min(channelwise_view, dim=1, keepdim=True)[0]) self._max_val_collector.append(torch.max(channelwise_view, dim=1, keepdim=True)[0]) - self._last_input = value.cpu().clone() + self._last_input = value else: raise TypeError('Min-max Observer only work with per-tensor or per-channel quantize policy.') diff --git a/ppq/quantization/optim/parameters.py b/ppq/quantization/optim/parameters.py index 7727f05e..4b45b929 100644 --- a/ppq/quantization/optim/parameters.py +++ b/ppq/quantization/optim/parameters.py @@ -125,8 +125,6 @@ def optimize( state_records[config] = config.state config.state = QuantizationStates.DEQUANTIZED elif self._method is not None: - # config.observer_algorithm = 'Minmax' - # pass # override quantizer's setting if necessary config.observer_algorithm = self._method diff --git a/ppq/quantization/quantizer/NCNNQuantizer.py b/ppq/quantization/quantizer/NCNNQuantizer.py index ae959b3e..9afc3cda 100644 --- a/ppq/quantization/quantizer/NCNNQuantizer.py +++ b/ppq/quantization/quantizer/NCNNQuantizer.py @@ -39,7 +39,7 @@ def init_quantize_config( observer_algorithm='Minmax' ) - if operation.type in {'Add', 'Concat', 'Conv', 'LayerNorm', 'MultiHeadAttention', 'Gemm', 'Gelu'}: + if operation.type in {'Add', 'Conv', 'LayerNorm', 'MultiHeadAttention', 'Gemm', 'Gelu'}: assert operation.num_of_input > 0, 'Seems you got a Computing layer with no parameters.' if operation.type == 'Conv': @@ -104,7 +104,7 @@ def init_quantize_config( wconfig.state = QuantizationStates.FP32 bconfig.state = QuantizationStates.FP32 - elif operation.type in {'Add', 'Concat'}: + elif operation.type in {'Add'}: # use default param pass @@ -123,7 +123,7 @@ def default_platform(self) -> TargetPlatform: @ property def quant_operation_types(self) -> set: return { - 'Add', 'Concat', 'Conv', 'LayerNorm', 'MultiHeadAttention', 'Gemm', 'Gelu' + 'Add', 'Conv', 'LayerNorm', 'MultiHeadAttention', 'Gemm', 'Gelu' } @ property diff --git a/ppq/quantization/quantizer/__init__.py b/ppq/quantization/quantizer/__init__.py index cfcf8e5f..5744d2c4 100644 --- a/ppq/quantization/quantizer/__init__.py +++ b/ppq/quantization/quantizer/__init__.py @@ -10,5 +10,5 @@ PPLCUDAMixPrecisionQuantizer, PPLCUDAQuantizer) from .TRTQuantizer import TensorRTQuantizer from .FPGAQuantizer import FPGAQuantizer -from .NCNNQuantizer import NCNNQuantizer +from . import NCNNQuantizer from .OpenvinoQuantizer import OpenvinoQuantizer \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 03756530..7e8af6a3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,3 +3,4 @@ onnx >= 1.8.1 protobuf torch >= 1.6.0 tqdm +toml \ No newline at end of file From 681df759ab54c49a2a6f7bee3b3ccf8f04431774 Mon Sep 17 00:00:00 2001 From: tpoisonooo Date: Sun, 12 Jun 2022 08:37:23 +0800 Subject: [PATCH 09/22] feat(MultiHeadAttention): add quant mha --- ppq/api/interface.py | 3 ++ ppq/executor/op/torch/default.py | 15 ++++---- ppq/quantization/quantizer/NCNNQuantizer.py | 40 +++++++++++++++++++++ 3 files changed, 51 insertions(+), 7 deletions(-) diff --git a/ppq/api/interface.py b/ppq/api/interface.py index 3570b5c5..26c7d9df 100644 --- a/ppq/api/interface.py +++ b/ppq/api/interface.py @@ -657,6 +657,9 @@ def dispatch_graph(graph: BaseGraph, platform: TargetPlatform, setting: Quantiza """ assert platform in QUANTIZER_COLLECTION, ( f'Platform misunderstood, except one of following platform {QUANTIZER_COLLECTION.keys()}') + + import pdb + pdb.set_trace() quantizer = QUANTIZER_COLLECTION[platform](graph) # 初始化一个 quantizer 没有很大代价... if str(setting.dispatcher).lower() not in DISPATCHER_TABLE: diff --git a/ppq/executor/op/torch/default.py b/ppq/executor/op/torch/default.py index d1e9da7b..f796f03a 100644 --- a/ppq/executor/op/torch/default.py +++ b/ppq/executor/op/torch/default.py @@ -337,7 +337,7 @@ def MultiHeadAttention_forward(op: Operation, values: List[torch.Tensor], ctx: T if len(values) != 11: raise NotImplementedError('Not implement simplified MultiHeadAttention') - q,k,v,q_w,q_b,k_w,k_b,v_w,v_b,o_w,o_b = values + q_in,k_in,v_in,q_w,q_b,k_w,k_b,v_w,v_b,o_w,o_b = values embed_dim = op.attributes.get('embed_dim') num_heads = op.attributes.get('num_heads') @@ -349,17 +349,18 @@ def MultiHeadAttention_forward(op: Operation, values: List[torch.Tensor], ctx: T head_dim = embed_dim // num_heads scale = head_dim ** -0.5 - q = F.linear(q, q_w, q_b) - k = F.linear(k, k_w, k_b) - v = F.linear(v, v_w, v_b) + q = F.linear(q_in, q_w, q_b) + k = F.linear(k_in, k_w, k_b) + v = F.linear(v_in, v_w, v_b) energy = (q @ k.transpose(-2, -1)) * scale attn = energy.softmax(dim=-1) - x = (attn @ v).transpose(1, 2).reshape(batch_size, -1, embed_dim) - x = F.linear(x, o_w, o_b) + feat = (attn @ v).transpose(1, 2).reshape(batch_size, -1, embed_dim) + out = F.linear(feat, o_w, o_b) - return x + # return out, q, k, v, energy, feat + return out def Add_forward(op: Operation, values: List[torch.Tensor], ctx: TorchBackendContext = None, **kwargs) -> torch.Tensor: diff --git a/ppq/quantization/quantizer/NCNNQuantizer.py b/ppq/quantization/quantizer/NCNNQuantizer.py index 9afc3cda..2eba0d3f 100644 --- a/ppq/quantization/quantizer/NCNNQuantizer.py +++ b/ppq/quantization/quantizer/NCNNQuantizer.py @@ -108,6 +108,46 @@ def init_quantize_config( # use default param pass + elif operation.type == 'MultiHeadAttention': + # setup weight quant param + fc_weight_config = base_quant_config.input_quantization_config[4] + fc_weight_config = QuantizationPolicy( + QuantizationProperty.SYMMETRICAL + + QuantizationProperty.LINEAR + + QuantizationProperty.PER_CHANNEL + ) + base_quant_config.input_quantization_config[3] = \ + ChannelwiseTensorQuantizationConfig.convert_from_tensor_config( + convert_from = fc_weight_config, + offsets = None, scales = None, channel_axis = 0 + ) + base_quant_config.input_quantization_config[5] = \ + ChannelwiseTensorQuantizationConfig.convert_from_tensor_config( + convert_from = fc_weight_config, + offsets = None, scales = None, channel_axis = 0 + ) + base_quant_config.input_quantization_config[7] = \ + ChannelwiseTensorQuantizationConfig.convert_from_tensor_config( + convert_from = fc_weight_config, + offsets = None, scales = None, channel_axis = 0 + ) + base_quant_config.input_quantization_config[9] = \ + ChannelwiseTensorQuantizationConfig.convert_from_tensor_config( + convert_from = fc_weight_config, + offsets = None, scales = None, channel_axis = 0 + ) + + internal_config = QuantizationPolicy( + QuantizationProperty.SYMMETRICAL + + QuantizationProperty.LINEAR + + QuantizationProperty.PER_TENSOR + ) + base_quant_config.output_quantization_config[1] = internal_config + base_quant_config.output_quantization_config[2] = internal_config + base_quant_config.output_quantization_config[3] = internal_config + base_quant_config.output_quantization_config[4] = internal_config + base_quant_config.output_quantization_config[5] = internal_config + # 显式说明输出不量化 base_quant_config.output_quantization_config[0].state = QuantizationStates.FP32 return base_quant_config From 13c07fbe390bfb7fc368a6a00d338f758c512cf3 Mon Sep 17 00:00:00 2001 From: tpoisonooo Date: Sun, 12 Jun 2022 15:27:52 +0800 Subject: [PATCH 10/22] feat(ncnn_exporter): support mha --- ppq/api/interface.py | 2 - ppq/executor/op/torch/default.py | 21 ++++-- ppq/parser/ncnn_exporter.py | 39 ++++++++++- ppq/quantization/quantizer/NCNNQuantizer.py | 72 ++++++++++++--------- ppq/quantization/quantizer/__init__.py | 2 +- 5 files changed, 96 insertions(+), 40 deletions(-) diff --git a/ppq/api/interface.py b/ppq/api/interface.py index 26c7d9df..ad9cf854 100644 --- a/ppq/api/interface.py +++ b/ppq/api/interface.py @@ -658,8 +658,6 @@ def dispatch_graph(graph: BaseGraph, platform: TargetPlatform, setting: Quantiza assert platform in QUANTIZER_COLLECTION, ( f'Platform misunderstood, except one of following platform {QUANTIZER_COLLECTION.keys()}') - import pdb - pdb.set_trace() quantizer = QUANTIZER_COLLECTION[platform](graph) # 初始化一个 quantizer 没有很大代价... if str(setting.dispatcher).lower() not in DISPATCHER_TABLE: diff --git a/ppq/executor/op/torch/default.py b/ppq/executor/op/torch/default.py index f796f03a..4b29bb80 100644 --- a/ppq/executor/op/torch/default.py +++ b/ppq/executor/op/torch/default.py @@ -333,7 +333,21 @@ def Mul_forward(op: Operation, values: List[torch.Tensor], ctx: TorchBackendCont return multiplicand * multiplier -def MultiHeadAttention_forward(op: Operation, values: List[torch.Tensor], ctx: TorchBackendContext = None, **kwargs) -> torch.Tensor: +def MultiHeadAttention_forward(op: Operation, values: List[torch.Tensor], ctx: TorchBackendContext = None, **kwargs) -> list: + """Perform MultiHeadAttetion opr forward. + + Args: + op (Operation): MultiHeadAttention + values (List[torch.Tensor]): opr inputs + ctx (TorchBackendContext, optional): Context. Defaults to None. + + Raises: + NotImplementedError: In [Vit Paper](https://arxiv.org/abs/2010.11929), MultiHeadAttention inputs are actually the same tensor, we suppose that this would **not** be simplified. + ValueError: MultiHeadAttention contains `embed_dim` and `num_heads`. + + Returns: + list: opr output and internal result for quantization. + """ if len(values) != 11: raise NotImplementedError('Not implement simplified MultiHeadAttention') @@ -345,7 +359,7 @@ def MultiHeadAttention_forward(op: Operation, values: List[torch.Tensor], ctx: T raise ValueError('Cannot fetch embed_dim or num_heads') # setup parameters - batch_size = q.shape[0] + batch_size = q_in.shape[0] head_dim = embed_dim // num_heads scale = head_dim ** -0.5 @@ -359,8 +373,7 @@ def MultiHeadAttention_forward(op: Operation, values: List[torch.Tensor], ctx: T feat = (attn @ v).transpose(1, 2).reshape(batch_size, -1, embed_dim) out = F.linear(feat, o_w, o_b) - # return out, q, k, v, energy, feat - return out + return out, q, k, v, energy, feat def Add_forward(op: Operation, values: List[torch.Tensor], ctx: TorchBackendContext = None, **kwargs) -> torch.Tensor: diff --git a/ppq/parser/ncnn_exporter.py b/ppq/parser/ncnn_exporter.py index ce81053f..53c4d3dd 100644 --- a/ppq/parser/ncnn_exporter.py +++ b/ppq/parser/ncnn_exporter.py @@ -95,7 +95,6 @@ def export_toml_quant_config(self, config_path: str, graph: BaseGraph): elif op.type in {'LayerNorm', 'Gelu'}: cfg = op.config.input_quantization_config[0] - print('cfg state {} algo {}'.format(cfg.state, cfg.observer_algorithm)) assert cfg.state in {QuantizationStates.BAKED, QuantizationStates.ACTIVATED} \ and cfg.observer_algorithm in {'minmax', 'Minmax'} @@ -103,9 +102,43 @@ def export_toml_quant_config(self, config_path: str, graph: BaseGraph): item['zero_point'] = convert_value(cfg.offset, False, DataType.INT32) elif op.type == 'MultiHeadAttention': + # write input scale + cfg_q_in = op.config.input_quantization_config[0] + cfg_k_in = op.config.input_quantization_config[1] + cfg_v_in = op.config.input_quantization_config[2] - # TODO - continue + item['input_scale_q'] = convert_value(1.0 / cfg_q_in.scale, True, DataType.FP32) + item['input_scale_k'] = convert_value(1.0 / cfg_k_in.scale, True, DataType.FP32) + item['input_scale_v'] = convert_value(1.0 / cfg_v_in.scale, True, DataType.FP32) + + # write input/output weight scale, per-channel + cfg_q_w = op.config.input_quantization_config[3] + cfg_k_w = op.config.input_quantization_config[5] + cfg_v_w = op.config.input_quantization_config[7] + cfg_o_w = op.config.input_quantization_config[9] + + item['weight_q'] = convert_value(1 / cfg_q_w.scale, False, DataType.FP32) + item['weight_k'] = convert_value(1 / cfg_k_w.scale, False, DataType.FP32) + item['weight_v'] = convert_value(1 / cfg_v_w.scale, False, DataType.FP32) + item['weight_o'] = convert_value(1 / cfg_o_w.scale, False, DataType.FP32) + + # write internal scale + + import pdb + pdb.set_trace() + + cfg_q = op.config.output_quantization_config[1] + cfg_k = op.config.output_quantization_config[2] + cfg_v = op.config.output_quantization_config[3] + cfg_energy = op.config.output_quantization_config[4] + cfg_feat = op.config.output_quantization_config[5] + + item['internal_scale_q'] = convert_value(1.0 / cfg_q.scale, True, DataType.FP32) + item['internal_scale_k'] = convert_value(1.0 / cfg_k.scale, True, DataType.FP32) + item['internal_scale_v'] = convert_value(1.0 / cfg_v.scale, True, DataType.FP32) + item['internal_scale_energy'] = convert_value(1.0 / cfg_energy.scale, True, DataType.FP32) + item['internal_scale_feat'] = convert_value(1.0 / cfg_feat.scale, True, DataType.FP32) + else: print('unknown quant type {} name {} during write weight scale'.format(op.type, op.name)) diff --git a/ppq/quantization/quantizer/NCNNQuantizer.py b/ppq/quantization/quantizer/NCNNQuantizer.py index 2eba0d3f..556d7c9d 100644 --- a/ppq/quantization/quantizer/NCNNQuantizer.py +++ b/ppq/quantization/quantizer/NCNNQuantizer.py @@ -109,47 +109,59 @@ def init_quantize_config( pass elif operation.type == 'MultiHeadAttention': + # setup input quant param + input_policy = QuantizationPolicy( + QuantizationProperty.SYMMETRICAL + + QuantizationProperty.LINEAR + + QuantizationProperty.PER_TENSOR + ) + input_indexes = [0, 1, 2] + for index in input_indexes: + base_quant_config.input_quantization_config[index].policy = input_policy + base_quant_config.input_quantization_config[index].observer_algorithm = 'Minmax' + # setup weight quant param - fc_weight_config = base_quant_config.input_quantization_config[4] - fc_weight_config = QuantizationPolicy( + fc_weight_config = base_quant_config.input_quantization_config[3] + fc_weight_config.policy = QuantizationPolicy( QuantizationProperty.SYMMETRICAL + QuantizationProperty.LINEAR + QuantizationProperty.PER_CHANNEL ) - base_quant_config.input_quantization_config[3] = \ - ChannelwiseTensorQuantizationConfig.convert_from_tensor_config( - convert_from = fc_weight_config, - offsets = None, scales = None, channel_axis = 0 - ) - base_quant_config.input_quantization_config[5] = \ - ChannelwiseTensorQuantizationConfig.convert_from_tensor_config( - convert_from = fc_weight_config, - offsets = None, scales = None, channel_axis = 0 - ) - base_quant_config.input_quantization_config[7] = \ - ChannelwiseTensorQuantizationConfig.convert_from_tensor_config( - convert_from = fc_weight_config, - offsets = None, scales = None, channel_axis = 0 - ) - base_quant_config.input_quantization_config[9] = \ - ChannelwiseTensorQuantizationConfig.convert_from_tensor_config( - convert_from = fc_weight_config, - offsets = None, scales = None, channel_axis = 0 - ) - - internal_config = QuantizationPolicy( + + # setup qkv weight quant param + fc_weight_indexes = [3, 5, 7, 9] + for index in fc_weight_indexes: + base_quant_config.input_quantization_config[index] = \ + ChannelwiseTensorQuantizationConfig.convert_from_tensor_config( + convert_from = fc_weight_config, + offsets = None, scales = None, channel_axis = 0 + ) + base_quant_config.input_quantization_config[index].observer_algorithm = 'Minmax' + + # bias not quant + bias_indexes = [4, 6, 8, 10] + for index in bias_indexes: + base_quant_config.input_quantization_config[index].state = QuantizationStates.FP32 + + # setup internal quant policy + # here be dragons: we treat internal result as opr output + internal_policy = QuantizationPolicy( QuantizationProperty.SYMMETRICAL + QuantizationProperty.LINEAR + QuantizationProperty.PER_TENSOR ) - base_quant_config.output_quantization_config[1] = internal_config - base_quant_config.output_quantization_config[2] = internal_config - base_quant_config.output_quantization_config[3] = internal_config - base_quant_config.output_quantization_config[4] = internal_config - base_quant_config.output_quantization_config[5] = internal_config + + internal_output_indexes = [1, 2, 3, 4, 5] + import pdb + pdb.set_trace() + for index in internal_output_indexes: + base_quant_config.output_quantization_config[index].policy = internal_policy + base_quant_config.output_quantization_config[index].observer_algorithm = 'Minmax' + # 显式说明输出不量化 - base_quant_config.output_quantization_config[0].state = QuantizationStates.FP32 + if operation.type != 'MultiHeadAttention': + base_quant_config.output_quantization_config[0].state = QuantizationStates.FP32 return base_quant_config @ property diff --git a/ppq/quantization/quantizer/__init__.py b/ppq/quantization/quantizer/__init__.py index 5744d2c4..cfcf8e5f 100644 --- a/ppq/quantization/quantizer/__init__.py +++ b/ppq/quantization/quantizer/__init__.py @@ -10,5 +10,5 @@ PPLCUDAMixPrecisionQuantizer, PPLCUDAQuantizer) from .TRTQuantizer import TensorRTQuantizer from .FPGAQuantizer import FPGAQuantizer -from . import NCNNQuantizer +from .NCNNQuantizer import NCNNQuantizer from .OpenvinoQuantizer import OpenvinoQuantizer \ No newline at end of file From ebad9a20ca8935a8bb18f931f060f832ab5337a5 Mon Sep 17 00:00:00 2001 From: tpoisonooo Date: Sun, 12 Jun 2022 16:44:37 +0800 Subject: [PATCH 11/22] feat(format_graph): add FORMAT_MHA --- ppq/IR/base/command.py | 2 ++ ppq/IR/morph.py | 37 ++++++++++++++++++++++++++++---- ppq/api/interface.py | 2 ++ ppq/executor/op/torch/default.py | 1 - ppq/executor/torch.py | 4 ++++ ppq/parser/ncnn_exporter.py | 4 ---- 6 files changed, 41 insertions(+), 9 deletions(-) diff --git a/ppq/IR/base/command.py b/ppq/IR/base/command.py index 42a9ad2e..1627ba59 100644 --- a/ppq/IR/base/command.py +++ b/ppq/IR/base/command.py @@ -96,6 +96,8 @@ class GraphCommandType(Enum): FORMAT_SLICE = 29 # 从一个指定位置将图截断 TRUNCATE_ON_VAR = 30 + # 输出 MultiHeadAttention 中间过程,从而让 Quantizer 配置是否量化 + FORMAT_MHA = 31 class GraphCommand(): def __init__(self, command_type: GraphCommandType, **kwargs) -> None: diff --git a/ppq/IR/morph.py b/ppq/IR/morph.py index ff5a2dca..befec857 100644 --- a/ppq/IR/morph.py +++ b/ppq/IR/morph.py @@ -90,7 +90,8 @@ def _acceptable_command_types(self) -> List[GraphCommandType]: GraphCommandType.FORMAT_PARAMETERS, GraphCommandType.FORMAT_CONSTANT_INPUT, GraphCommandType.FORMAT_SLICE, - GraphCommandType.TRUNCATE_ON_VAR + GraphCommandType.TRUNCATE_ON_VAR, + GraphCommandType.FORMAT_MHA, ] def process(self, command: GraphCommand) -> Any: @@ -107,7 +108,7 @@ def process(self, command: GraphCommand) -> Any: if command.command_type == GraphCommandType.FORMAT_INT64_CONSTANT: return self.format_int64_constant() if command.command_type == GraphCommandType.REPLACE_SUB: - return self.replace_substarction() + return self.replace_substraction() if command.command_type == GraphCommandType.FORMAT_PARAMETERS: return self.format_parameter_variables() if command.command_type == GraphCommandType.FORMAT_CONSTANT_INPUT: @@ -117,6 +118,8 @@ def process(self, command: GraphCommand) -> Any: if command.command_type == GraphCommandType.TRUNCATE_ON_VAR: assert isinstance(command, TruncateGraphCommand), f'Use TruncateGraphCommand here.' return self.truncate_on_var(command.var, command.mark_as_output) + if command.command_type == GraphCommandType.FORMAT_MHA: + return self.format_mha() def format_slice(self) -> None: """ @@ -398,8 +401,34 @@ def format_parameter_variables(self) -> None: # pop variable from graph self.graph.remove_variable(var) - - def replace_substarction(self) -> None: + + def format_mha(self) -> None: + mha = [] + for operation in self.graph.operations.values(): + if operation.type == 'MultiHeadAttention': + mha.append(operation) + + for opr in mha: + assert isinstance(opr, Operation) + q_var = Variable(name=opr.name + '_fake_q_', source_op=opr) + k_var = Variable(name=opr.name + '_fake_k_', source_op=opr) + v_var = Variable(name=opr.name + '_fake_v_', source_op=opr) + energy_var = Variable(name=opr.name + '_fake_energy_', source_op=opr) + feat_var = Variable(name=opr.name + '_fake_feat_', source_op=opr) + + opr.outputs.append(q_var) + opr.outputs.append(k_var) + opr.outputs.append(v_var) + opr.outputs.append(energy_var) + opr.outputs.append(feat_var) + + self.graph.append_variable(q_var) + self.graph.append_variable(k_var) + self.graph.append_variable(v_var) + self.graph.append_variable(energy_var) + self.graph.append_variable(feat_var) + + def replace_substraction(self) -> None: substractions = [] for operation in self.graph.operations.values(): if operation.type == 'Sub': diff --git a/ppq/api/interface.py b/ppq/api/interface.py index ad9cf854..4ce638fb 100644 --- a/ppq/api/interface.py +++ b/ppq/api/interface.py @@ -617,6 +617,7 @@ def format_graph(graph: BaseGraph) -> BaseGraph: 在 PPQ 中,我们不希望出现 Constant 算子,所有 Constant 输入将被当作 parameter variable 连接到下游算子上 在 PPQ 中,我们不希望出现 Batchnorm 算子,所有 Batchnorm 将被合并 在 PPQ 中,我们不希望出现权重共享的算子,所有被共享的权重将被复制分裂成多份 + 在 PPQ 中,我们希望 MultiHeadAttention 算子中间过程可被量化 在 PPQ 中,我们不希望出现孤立算子,所有孤立算子将被移除 This function takes pre-processing procedure with your graph. @@ -639,6 +640,7 @@ def format_graph(graph: BaseGraph) -> BaseGraph: formatter(GraphCommand(GraphCommandType.FORMAT_CAST)) formatter(GraphCommand(GraphCommandType.FORMAT_SLICE)) formatter(GraphCommand(GraphCommandType.FORMAT_CLIP)) + formatter(GraphCommand(GraphCommandType.FORMAT_MHA)) formatter(GraphCommand(GraphCommandType.DELETE_ISOLATED)) return graph diff --git a/ppq/executor/op/torch/default.py b/ppq/executor/op/torch/default.py index 4b29bb80..29f75b70 100644 --- a/ppq/executor/op/torch/default.py +++ b/ppq/executor/op/torch/default.py @@ -372,7 +372,6 @@ def MultiHeadAttention_forward(op: Operation, values: List[torch.Tensor], ctx: T feat = (attn @ v).transpose(1, 2).reshape(batch_size, -1, embed_dim) out = F.linear(feat, o_w, o_b) - return out, q, k, v, energy, feat diff --git a/ppq/executor/torch.py b/ppq/executor/torch.py index 7d145075..1b894c52 100644 --- a/ppq/executor/torch.py +++ b/ppq/executor/torch.py @@ -341,6 +341,10 @@ def __forward( operation_forward_func = platform_dispatching_table[operation.type] operation_runtime_hook = hooks[operation.name] if (hooks is not None) and (operation.name in hooks) else None inputs = [var.value for var in operation.inputs] + + if operation.name == 'MultiHeadAttention_31': + import pdb + pdb.set_trace() # if operation is an QuantableOperation, we have to quant its inputs and outputs at first. if isinstance(operation, QuantableOperation): diff --git a/ppq/parser/ncnn_exporter.py b/ppq/parser/ncnn_exporter.py index 53c4d3dd..4f0ac218 100644 --- a/ppq/parser/ncnn_exporter.py +++ b/ppq/parser/ncnn_exporter.py @@ -123,10 +123,6 @@ def export_toml_quant_config(self, config_path: str, graph: BaseGraph): item['weight_o'] = convert_value(1 / cfg_o_w.scale, False, DataType.FP32) # write internal scale - - import pdb - pdb.set_trace() - cfg_q = op.config.output_quantization_config[1] cfg_k = op.config.output_quantization_config[2] cfg_v = op.config.output_quantization_config[3] From 9b128d9ef4caae4e5a7022cfe0eb113f39018c33 Mon Sep 17 00:00:00 2001 From: tpoisonooo Date: Sun, 12 Jun 2022 21:27:39 +0800 Subject: [PATCH 12/22] style(ppq): clean code --- ppq/api/setting.py | 2 +- ppq/executor/torch.py | 4 ---- ppq/parser/ncnn_exporter.py | 2 -- ppq/quantization/quantizer/NCNNQuantizer.py | 2 -- 4 files changed, 1 insertion(+), 9 deletions(-) diff --git a/ppq/api/setting.py b/ppq/api/setting.py index f7cf0724..ac488b9d 100644 --- a/ppq/api/setting.py +++ b/ppq/api/setting.py @@ -164,7 +164,7 @@ class ActivationQuantizationSetting(): def __init__(self) -> None: # 激活值校准算法,不区分大小写,可以选择 minmax, kl, percentile, MSE # activation calibration method - self.calib_algorithm = 'minmax' + self.calib_algorithm = None # 执行逐层激活值校准,延长执行时间,提升精度 # whether to calibrate activation per - layer. diff --git a/ppq/executor/torch.py b/ppq/executor/torch.py index 1b894c52..7d145075 100644 --- a/ppq/executor/torch.py +++ b/ppq/executor/torch.py @@ -341,10 +341,6 @@ def __forward( operation_forward_func = platform_dispatching_table[operation.type] operation_runtime_hook = hooks[operation.name] if (hooks is not None) and (operation.name in hooks) else None inputs = [var.value for var in operation.inputs] - - if operation.name == 'MultiHeadAttention_31': - import pdb - pdb.set_trace() # if operation is an QuantableOperation, we have to quant its inputs and outputs at first. if isinstance(operation, QuantableOperation): diff --git a/ppq/parser/ncnn_exporter.py b/ppq/parser/ncnn_exporter.py index 4f0ac218..f14ae4d7 100644 --- a/ppq/parser/ncnn_exporter.py +++ b/ppq/parser/ncnn_exporter.py @@ -1,4 +1,3 @@ -from lib2to3.pytree import convert from typing import List from ppq.core import (DataType, NetworkFramework, QuantizationProperty, @@ -8,7 +7,6 @@ from .caffe_exporter import CaffeExporter from .onnx_exporter import OnnxExporter from .util import convert_value -from collections import Iterable class NCNNExporter(GraphExporter): diff --git a/ppq/quantization/quantizer/NCNNQuantizer.py b/ppq/quantization/quantizer/NCNNQuantizer.py index 556d7c9d..029c6075 100644 --- a/ppq/quantization/quantizer/NCNNQuantizer.py +++ b/ppq/quantization/quantizer/NCNNQuantizer.py @@ -152,8 +152,6 @@ def init_quantize_config( ) internal_output_indexes = [1, 2, 3, 4, 5] - import pdb - pdb.set_trace() for index in internal_output_indexes: base_quant_config.output_quantization_config[index].policy = internal_policy base_quant_config.output_quantization_config[index].observer_algorithm = 'Minmax' From cdc4165f52ab2a3f2bfc7282cb0f27f20021fa47 Mon Sep 17 00:00:00 2001 From: tpoisonooo Date: Tue, 14 Jun 2022 20:40:37 +0800 Subject: [PATCH 13/22] feat(toml): remove tail separator --- ppq/parser/ncnn_exporter.py | 61 +++++++++++++++++++++++++++++-------- 1 file changed, 48 insertions(+), 13 deletions(-) diff --git a/ppq/parser/ncnn_exporter.py b/ppq/parser/ncnn_exporter.py index f14ae4d7..612a11f5 100644 --- a/ppq/parser/ncnn_exporter.py +++ b/ppq/parser/ncnn_exporter.py @@ -8,6 +8,38 @@ from .onnx_exporter import OnnxExporter from .util import convert_value +import toml + +# rewrite toml encoder +class ArrayEncoder(toml.TomlEncoder): + + def __init__(self, _dict=dict, preserve=False, separator=","): + super(ArrayEncoder, self).__init__(_dict, preserve) + if separator.strip() == "": + separator = "," + separator + elif separator.strip(' \t\n\r,'): + raise ValueError("Invalid separator for arrays") + self.separator = separator + + def dump_list(self, v): + t = [] + retval = "[" + for u in v: + t.append(self.dump_value(u)) + while t != []: + s = [] + last = len(t) - 1 + for idx, u in enumerate(t): + if isinstance(u, list): + for r in u: + s.append(r) + elif idx != last: + retval += " " + str(u) + self.separator + else: + retval += " " + str(u) + " " + t = s + retval += "]" + return retval class NCNNExporter(GraphExporter): @@ -47,18 +79,19 @@ def export_raw_quant_config(self, config_path: str, graph: BaseGraph): def export_toml_quant_config(self, config_path: str, graph: BaseGraph): ''' toml is human readable format ''' - import toml - from ppq.core.config import PPQ_CONFIG - - table = {'source': '{} {}'.format(PPQ_CONFIG.NAME, PPQ_CONFIG.VERSION)} order = graph.topological_sort() - + table = {} for op in order: if hasattr(op, 'config'): item = dict() # avoiding Gather to Crop, we cannot deduce opr_type from opr_name item['type'] = op.type if op.type in {'Conv', 'Gemm'}: + input_cfg = op.config.input_quantization_config[0] + assert input_cfg.state == QuantizationStates.ACTIVATED and \ + input_cfg.policy.has_property(QuantizationProperty.PER_TENSOR) + item['input_scale'] = convert_value(1 / input_cfg.scale, True, DataType.FP32) + param_cfg = op.config.input_quantization_config[1] assert param_cfg.state in {QuantizationStates.BAKED, QuantizationStates.ACTIVATED}\ and param_cfg.observer_algorithm in {'minmax', 'Minmax'} and \ @@ -71,11 +104,6 @@ def export_toml_quant_config(self, config_path: str, graph: BaseGraph): else: scale = param_cfg.scale item['weight'] = convert_value(1 / scale, False, DataType.FP32) - - input_cfg = op.config.input_quantization_config[0] - assert input_cfg.state == QuantizationStates.ACTIVATED and \ - input_cfg.policy.has_property(QuantizationProperty.PER_TENSOR) - item['input_scale'] = convert_value(1 / input_cfg.scale, True, DataType.FP32) elif op.type in {'Add'}: # Add may have multiple input node @@ -89,9 +117,15 @@ def export_toml_quant_config(self, config_path: str, graph: BaseGraph): zero_point.extend(convert_value(cfg.offset, False, DataType.INT32)) item['input_scale'] = input_scale - item['zero_point'] = zero_point - elif op.type in {'LayerNorm', 'Gelu'}: + elif op.type in {'Gelu'}: + cfg = op.config.input_quantization_config[0] + + assert cfg.state in {QuantizationStates.BAKED, QuantizationStates.ACTIVATED} \ + and cfg.observer_algorithm in {'minmax', 'Minmax'} + item['input_scale'] = convert_value(1.0 / cfg.scale, False, DataType.FP32) + + elif op.type in {'LayerNorm'}: cfg = op.config.input_quantization_config[0] assert cfg.state in {QuantizationStates.BAKED, QuantizationStates.ACTIVATED} \ @@ -99,6 +133,7 @@ def export_toml_quant_config(self, config_path: str, graph: BaseGraph): item['input_scale'] = convert_value(1.0 / cfg.scale, False, DataType.FP32) item['zero_point'] = convert_value(cfg.offset, False, DataType.INT32) + elif op.type == 'MultiHeadAttention': # write input scale cfg_q_in = op.config.input_quantization_config[0] @@ -138,7 +173,7 @@ def export_toml_quant_config(self, config_path: str, graph: BaseGraph): table[op.name] = item - toml.dump(table, open(config_path, 'w+')) + toml.dump(table, open(config_path, 'w+'), encoder=ArrayEncoder()) def export_quantization_config(self, config_path: str, graph: BaseGraph): From 65970473922340f94f3e0783b93225dbd1438612 Mon Sep 17 00:00:00 2001 From: tpoisonooo Date: Thu, 23 Jun 2022 21:35:10 +0800 Subject: [PATCH 14/22] fix(default.py): mha_forward error --- ppq/executor/op/torch/default.py | 12 ++- ppq/executor/torch.py | 4 +- ppq/parser/ncnn_exporter.py | 95 +++++++++++---------- ppq/quantization/quantizer/NCNNQuantizer.py | 6 +- 4 files changed, 65 insertions(+), 52 deletions(-) diff --git a/ppq/executor/op/torch/default.py b/ppq/executor/op/torch/default.py index 29f75b70..599585e5 100644 --- a/ppq/executor/op/torch/default.py +++ b/ppq/executor/op/torch/default.py @@ -363,9 +363,15 @@ def MultiHeadAttention_forward(op: Operation, values: List[torch.Tensor], ctx: T head_dim = embed_dim // num_heads scale = head_dim ** -0.5 - q = F.linear(q_in, q_w, q_b) - k = F.linear(k_in, k_w, k_b) - v = F.linear(v_in, v_w, v_b) + B, N, _ = q.shape + + q_tmp = F.linear(q_in, q_w, q_b) + k_tmp = F.linear(k_in, k_w, k_b) + v_tmp = F.linear(v_in, v_w, v_b) + + q = q_tmp.reshape(B, N, num_heads, head_dim).permute(0, 2, 1, 3) + k = k_tmp.reshape(B, N, num_heads, head_dim).permute(0, 2, 1, 3) + v = v_tmp.reshape(B, N, num_heads, head_dim).permute(0, 2, 1, 3) energy = (q @ k.transpose(-2, -1)) * scale attn = energy.softmax(dim=-1) diff --git a/ppq/executor/torch.py b/ppq/executor/torch.py index 7d145075..fdab7b1b 100644 --- a/ppq/executor/torch.py +++ b/ppq/executor/torch.py @@ -367,8 +367,8 @@ def __forward( outputs = outputs if isinstance(outputs, (list, tuple)) else [outputs] fp_outputs = outputs - for out in fp_outputs: - print("{} output shape {}".format(operation.name, out.shape)) + # for out in fp_outputs: + # print("{} output shape {}".format(operation.name, out.shape)) # quantize all result if is necessary if isinstance(operation, QuantableOperation): diff --git a/ppq/parser/ncnn_exporter.py b/ppq/parser/ncnn_exporter.py index 612a11f5..895ec618 100644 --- a/ppq/parser/ncnn_exporter.py +++ b/ppq/parser/ncnn_exporter.py @@ -77,7 +77,7 @@ def export_raw_quant_config(self, config_path: str, graph: BaseGraph): fd.close() - def export_toml_quant_config(self, config_path: str, graph: BaseGraph): + def export_ini_quant_config(self, config_path: str, graph: BaseGraph): ''' toml is human readable format ''' order = graph.topological_sort() table = {} @@ -86,55 +86,56 @@ def export_toml_quant_config(self, config_path: str, graph: BaseGraph): item = dict() # avoiding Gather to Crop, we cannot deduce opr_type from opr_name item['type'] = op.type - if op.type in {'Conv', 'Gemm'}: - input_cfg = op.config.input_quantization_config[0] - assert input_cfg.state == QuantizationStates.ACTIVATED and \ - input_cfg.policy.has_property(QuantizationProperty.PER_TENSOR) - item['input_scale'] = convert_value(1 / input_cfg.scale, True, DataType.FP32) + # if op.type in {'Conv', 'Gemm'}: + # input_cfg = op.config.input_quantization_config[0] + # assert input_cfg.state == QuantizationStates.ACTIVATED and \ + # input_cfg.policy.has_property(QuantizationProperty.PER_TENSOR) + # item['input_scale'] = convert_value(1 / input_cfg.scale, True, DataType.FP32) - param_cfg = op.config.input_quantization_config[1] - assert param_cfg.state in {QuantizationStates.BAKED, QuantizationStates.ACTIVATED}\ - and param_cfg.observer_algorithm in {'minmax', 'Minmax'} and \ - param_cfg.policy.has_property(QuantizationProperty.PER_CHANNEL) - # a workaround for depthwise conv in ncnn - # will cause mis-alignment between ppq and ncnn - if op.type == 'Conv' and op.attributes.get('group', 1) > 1: - group = op.attributes.get('group', 1) - scale = param_cfg.scale.reshape(group, -1).max(dim=1)[0] - else: - scale = param_cfg.scale - item['weight'] = convert_value(1 / scale, False, DataType.FP32) - - elif op.type in {'Add'}: - # Add may have multiple input node - input_scale = [] - zero_point = [] + # param_cfg = op.config.input_quantization_config[1] + # assert param_cfg.state in {QuantizationStates.BAKED, QuantizationStates.ACTIVATED}\ + # and param_cfg.observer_algorithm in {'minmax', 'Minmax'} and \ + # param_cfg.policy.has_property(QuantizationProperty.PER_CHANNEL) + # # a workaround for depthwise conv in ncnn + # # will cause mis-alignment between ppq and ncnn + # if op.type == 'Conv' and op.attributes.get('group', 1) > 1: + # group = op.attributes.get('group', 1) + # scale = param_cfg.scale.reshape(group, -1).max(dim=1)[0] + # else: + # scale = param_cfg.scale + # item['weight'] = convert_value(1 / scale, False, DataType.FP32) + + # elif op.type in {'Add'}: + # # Add may have multiple input node + # input_scale = [] + # zero_point = [] - for cfg in op.config.input_quantization_config: - assert cfg.state in {QuantizationStates.BAKED, QuantizationStates.ACTIVATED, QuantizationStates.SLAVE} \ - and cfg.observer_algorithm in {'minmax', 'Minmax'} - input_scale.append(convert_value(1.0 / cfg.scale, True, DataType.FP32)) - zero_point.extend(convert_value(cfg.offset, False, DataType.INT32)) + # for cfg in op.config.input_quantization_config: + # assert cfg.state in {QuantizationStates.BAKED, QuantizationStates.ACTIVATED, QuantizationStates.SLAVE} \ + # and cfg.observer_algorithm in {'minmax', 'Minmax'} + # input_scale.append(convert_value(1.0 / cfg.scale, True, DataType.FP32)) + # zero_point.extend(convert_value(cfg.offset, False, DataType.INT32)) - item['input_scale'] = input_scale + # item['input_scale'] = input_scale - elif op.type in {'Gelu'}: - cfg = op.config.input_quantization_config[0] + # elif op.type in {'Gelu'}: + # cfg = op.config.input_quantization_config[0] - assert cfg.state in {QuantizationStates.BAKED, QuantizationStates.ACTIVATED} \ - and cfg.observer_algorithm in {'minmax', 'Minmax'} - item['input_scale'] = convert_value(1.0 / cfg.scale, False, DataType.FP32) + # assert cfg.state in {QuantizationStates.BAKED, QuantizationStates.ACTIVATED} \ + # and cfg.observer_algorithm in {'minmax', 'Minmax'} + # item['input_scale'] = convert_value(1.0 / cfg.scale, True, DataType.FP32) - elif op.type in {'LayerNorm'}: - cfg = op.config.input_quantization_config[0] + # elif op.type in {'LayerNorm'}: + # cfg = op.config.input_quantization_config[0] - assert cfg.state in {QuantizationStates.BAKED, QuantizationStates.ACTIVATED} \ - and cfg.observer_algorithm in {'minmax', 'Minmax'} - item['input_scale'] = convert_value(1.0 / cfg.scale, False, DataType.FP32) - item['zero_point'] = convert_value(cfg.offset, False, DataType.INT32) + # assert cfg.state in {QuantizationStates.BAKED, QuantizationStates.ACTIVATED} \ + # and cfg.observer_algorithm in {'minmax', 'Minmax'} + # item['input_scale'] = convert_value(1.0 / cfg.scale, False, DataType.FP32) + # item['zero_point'] = convert_value(cfg.offset, True, DataType.INT32) - elif op.type == 'MultiHeadAttention': + # el + if op.type == 'MultiHeadAttention': # write input scale cfg_q_in = op.config.input_quantization_config[0] cfg_k_in = op.config.input_quantization_config[1] @@ -167,19 +168,21 @@ def export_toml_quant_config(self, config_path: str, graph: BaseGraph): item['internal_scale_v'] = convert_value(1.0 / cfg_v.scale, True, DataType.FP32) item['internal_scale_energy'] = convert_value(1.0 / cfg_energy.scale, True, DataType.FP32) item['internal_scale_feat'] = convert_value(1.0 / cfg_feat.scale, True, DataType.FP32) - + + table[op.name] = item + else: print('unknown quant type {} name {} during write weight scale'.format(op.type, op.name)) - table[op.name] = item + print("dump path {}".format(config_path)) toml.dump(table, open(config_path, 'w+'), encoder=ArrayEncoder()) def export_quantization_config(self, config_path: str, graph: BaseGraph): - toml_style = True - if toml_style: - self.export_toml_quant_config(config_path=config_path, graph=graph) + ini_style = True + if ini_style: + self.export_ini_quant_config(config_path=config_path, graph=graph) else: self.export_raw_quant_config(config_path=config_path, graph=graph) diff --git a/ppq/quantization/quantizer/NCNNQuantizer.py b/ppq/quantization/quantizer/NCNNQuantizer.py index 029c6075..67978b02 100644 --- a/ppq/quantization/quantizer/NCNNQuantizer.py +++ b/ppq/quantization/quantizer/NCNNQuantizer.py @@ -172,10 +172,14 @@ def default_platform(self) -> TargetPlatform: @ property def quant_operation_types(self) -> set: + # return { + # 'Add', 'Conv', 'LayerNorm', 'MultiHeadAttention', 'Gemm', 'Gelu' + # } return { - 'Add', 'Conv', 'LayerNorm', 'MultiHeadAttention', 'Gemm', 'Gelu' + 'MultiHeadAttention' } + @ property def quantize_policy(self) -> QuantizationPolicy: return QuantizationPolicy( From 99133cabe2a0ad08e6fb26aa3e8af9c0eb83cbc5 Mon Sep 17 00:00:00 2001 From: tpoisonooo Date: Mon, 27 Jun 2022 15:17:16 +0800 Subject: [PATCH 15/22] improvement(NCNNQuantizer): clean code --- ppq/executor/op/torch/default.py | 17 ++-- ppq/parser/ncnn_exporter.py | 99 +++++++++++---------- ppq/quantization/quantizer/NCNNQuantizer.py | 5 +- ppq/quantization/quantizer/base.py | 9 -- 4 files changed, 60 insertions(+), 70 deletions(-) diff --git a/ppq/executor/op/torch/default.py b/ppq/executor/op/torch/default.py index 599585e5..fe846350 100644 --- a/ppq/executor/op/torch/default.py +++ b/ppq/executor/op/torch/default.py @@ -363,21 +363,22 @@ def MultiHeadAttention_forward(op: Operation, values: List[torch.Tensor], ctx: T head_dim = embed_dim // num_heads scale = head_dim ** -0.5 - B, N, _ = q.shape - - q_tmp = F.linear(q_in, q_w, q_b) - k_tmp = F.linear(k_in, k_w, k_b) - v_tmp = F.linear(v_in, v_w, v_b) + xq = F.linear(q_in, q_w, q_b) + xk = F.linear(k_in, k_w, k_b) + xv = F.linear(v_in, v_w, v_b) + + B, N, _ = xq.shape - q = q_tmp.reshape(B, N, num_heads, head_dim).permute(0, 2, 1, 3) - k = k_tmp.reshape(B, N, num_heads, head_dim).permute(0, 2, 1, 3) - v = v_tmp.reshape(B, N, num_heads, head_dim).permute(0, 2, 1, 3) + q = xq.reshape(B, N, num_heads, head_dim).permute(0, 2, 1, 3) + k = xk.reshape(B, N, num_heads, head_dim).permute(0, 2, 1, 3) + v = xv.reshape(B, N, num_heads, head_dim).permute(0, 2, 1, 3) energy = (q @ k.transpose(-2, -1)) * scale attn = energy.softmax(dim=-1) feat = (attn @ v).transpose(1, 2).reshape(batch_size, -1, embed_dim) out = F.linear(feat, o_w, o_b) + return out, q, k, v, energy, feat diff --git a/ppq/parser/ncnn_exporter.py b/ppq/parser/ncnn_exporter.py index 895ec618..3661c1a2 100644 --- a/ppq/parser/ncnn_exporter.py +++ b/ppq/parser/ncnn_exporter.py @@ -42,12 +42,14 @@ def dump_list(self, v): return retval class NCNNExporter(GraphExporter): - + ''' raw format only support Conv and Gemm quantization ''' def export_raw_quant_config(self, config_path: str, graph: BaseGraph): ''' ncnn table format when version <= 20220420 ''' fd = open(config_path, 'w+') topo_order = graph.topological_sort() for op in topo_order: + if op.type not in {'Conv', 'Gemm'}: + continue if op.is_computing_op and isinstance(op, QuantableOperation): fd.write(f'{op.name}_param_0 ') param_cfg = op.config.input_quantization_config[1] @@ -66,6 +68,8 @@ def export_raw_quant_config(self, config_path: str, graph: BaseGraph): fd.write('%f '% s) fd.write('\n') for op in topo_order: + if op.type not in {'Conv', 'Gemm'}: + continue if op.is_computing_op and isinstance(op, QuantableOperation): fd.write(f'{op.name} ') input_cfg = op.config.input_quantization_config[0] @@ -86,56 +90,54 @@ def export_ini_quant_config(self, config_path: str, graph: BaseGraph): item = dict() # avoiding Gather to Crop, we cannot deduce opr_type from opr_name item['type'] = op.type - # if op.type in {'Conv', 'Gemm'}: - # input_cfg = op.config.input_quantization_config[0] - # assert input_cfg.state == QuantizationStates.ACTIVATED and \ - # input_cfg.policy.has_property(QuantizationProperty.PER_TENSOR) - # item['input_scale'] = convert_value(1 / input_cfg.scale, True, DataType.FP32) + if op.type in {'Conv', 'Gemm'}: + input_cfg = op.config.input_quantization_config[0] + assert input_cfg.state == QuantizationStates.ACTIVATED and \ + input_cfg.policy.has_property(QuantizationProperty.PER_TENSOR) + item['input_scale'] = convert_value(1 / input_cfg.scale, True, DataType.FP32) - # param_cfg = op.config.input_quantization_config[1] - # assert param_cfg.state in {QuantizationStates.BAKED, QuantizationStates.ACTIVATED}\ - # and param_cfg.observer_algorithm in {'minmax', 'Minmax'} and \ - # param_cfg.policy.has_property(QuantizationProperty.PER_CHANNEL) - # # a workaround for depthwise conv in ncnn - # # will cause mis-alignment between ppq and ncnn - # if op.type == 'Conv' and op.attributes.get('group', 1) > 1: - # group = op.attributes.get('group', 1) - # scale = param_cfg.scale.reshape(group, -1).max(dim=1)[0] - # else: - # scale = param_cfg.scale - # item['weight'] = convert_value(1 / scale, False, DataType.FP32) - - # elif op.type in {'Add'}: - # # Add may have multiple input node - # input_scale = [] - # zero_point = [] + param_cfg = op.config.input_quantization_config[1] + assert param_cfg.state in {QuantizationStates.BAKED, QuantizationStates.ACTIVATED}\ + and param_cfg.observer_algorithm in {'minmax', 'Minmax'} and \ + param_cfg.policy.has_property(QuantizationProperty.PER_CHANNEL) + # a workaround for depthwise conv in ncnn + # will cause mis-alignment between ppq and ncnn + if op.type == 'Conv' and op.attributes.get('group', 1) > 1: + group = op.attributes.get('group', 1) + scale = param_cfg.scale.reshape(group, -1).max(dim=1)[0] + else: + scale = param_cfg.scale + item['weight'] = convert_value(1 / scale, False, DataType.FP32) + + elif op.type in {'Add'}: + # Add may have multiple input node + input_scale = [] + zero_point = [] - # for cfg in op.config.input_quantization_config: - # assert cfg.state in {QuantizationStates.BAKED, QuantizationStates.ACTIVATED, QuantizationStates.SLAVE} \ - # and cfg.observer_algorithm in {'minmax', 'Minmax'} - # input_scale.append(convert_value(1.0 / cfg.scale, True, DataType.FP32)) - # zero_point.extend(convert_value(cfg.offset, False, DataType.INT32)) + for cfg in op.config.input_quantization_config: + assert cfg.state in {QuantizationStates.BAKED, QuantizationStates.ACTIVATED, QuantizationStates.SLAVE} \ + and cfg.observer_algorithm in {'minmax', 'Minmax'} + input_scale.append(convert_value(1.0 / cfg.scale, True, DataType.FP32)) + zero_point.extend(convert_value(cfg.offset, False, DataType.INT32)) - # item['input_scale'] = input_scale + item['input_scale'] = input_scale - # elif op.type in {'Gelu'}: - # cfg = op.config.input_quantization_config[0] + elif op.type in {'Gelu'}: + cfg = op.config.input_quantization_config[0] - # assert cfg.state in {QuantizationStates.BAKED, QuantizationStates.ACTIVATED} \ - # and cfg.observer_algorithm in {'minmax', 'Minmax'} - # item['input_scale'] = convert_value(1.0 / cfg.scale, True, DataType.FP32) + assert cfg.state in {QuantizationStates.BAKED, QuantizationStates.ACTIVATED} \ + and cfg.observer_algorithm in {'minmax', 'Minmax'} + item['input_scale'] = convert_value(1.0 / cfg.scale, True, DataType.FP32) - # elif op.type in {'LayerNorm'}: - # cfg = op.config.input_quantization_config[0] + elif op.type in {'LayerNorm'}: + cfg = op.config.input_quantization_config[0] - # assert cfg.state in {QuantizationStates.BAKED, QuantizationStates.ACTIVATED} \ - # and cfg.observer_algorithm in {'minmax', 'Minmax'} - # item['input_scale'] = convert_value(1.0 / cfg.scale, False, DataType.FP32) - # item['zero_point'] = convert_value(cfg.offset, True, DataType.INT32) + assert cfg.state in {QuantizationStates.BAKED, QuantizationStates.ACTIVATED} \ + and cfg.observer_algorithm in {'minmax', 'Minmax'} + item['input_scale'] = convert_value(1.0 / cfg.scale, False, DataType.FP32) + item['zero_point'] = convert_value(cfg.offset, True, DataType.INT32) - - # el - if op.type == 'MultiHeadAttention': + elif op.type == 'MultiHeadAttention': # write input scale cfg_q_in = op.config.input_quantization_config[0] cfg_k_in = op.config.input_quantization_config[1] @@ -169,19 +171,18 @@ def export_ini_quant_config(self, config_path: str, graph: BaseGraph): item['internal_scale_energy'] = convert_value(1.0 / cfg_energy.scale, True, DataType.FP32) item['internal_scale_feat'] = convert_value(1.0 / cfg_feat.scale, True, DataType.FP32) - table[op.name] = item - else: print('unknown quant type {} name {} during write weight scale'.format(op.type, op.name)) + continue + + table[op.name] = item - - print("dump path {}".format(config_path)) toml.dump(table, open(config_path, 'w+'), encoder=ArrayEncoder()) def export_quantization_config(self, config_path: str, graph: BaseGraph): - ini_style = True - if ini_style: + if config_path.endswith(".ini"): + print("export .ini format quant table, please make sure ncnn version >= 20220627") self.export_ini_quant_config(config_path=config_path, graph=graph) else: self.export_raw_quant_config(config_path=config_path, graph=graph) diff --git a/ppq/quantization/quantizer/NCNNQuantizer.py b/ppq/quantization/quantizer/NCNNQuantizer.py index 67978b02..dc10523c 100644 --- a/ppq/quantization/quantizer/NCNNQuantizer.py +++ b/ppq/quantization/quantizer/NCNNQuantizer.py @@ -172,11 +172,8 @@ def default_platform(self) -> TargetPlatform: @ property def quant_operation_types(self) -> set: - # return { - # 'Add', 'Conv', 'LayerNorm', 'MultiHeadAttention', 'Gemm', 'Gelu' - # } return { - 'MultiHeadAttention' + 'Add', 'Conv', 'LayerNorm', 'MultiHeadAttention', 'Gemm', 'Gelu' } diff --git a/ppq/quantization/quantizer/base.py b/ppq/quantization/quantizer/base.py index b8a547b8..7ca6fe09 100644 --- a/ppq/quantization/quantizer/base.py +++ b/ppq/quantization/quantizer/base.py @@ -232,19 +232,10 @@ def report(self) -> str: debug_str += '--------- Network Snapshot ---------\n' debug_str += f'Num of Op: [{len(self._graph.operations)}]\n' debug_str += f'Num of Quantized Op: [{len(quant_ops)}]\n' - # import pdb - # pdb.set_trace() - debug_str += " ".join([op.name for op in quant_ops]) - debug_str += '\n' debug_str += f'Num of Variable: [{len(self._graph.variables)}]\n' debug_str += f'Num of Quantized Var: [{len(quant_vars)}]\n' - debug_str += " ".join([var.name for var in quant_vars]) debug_str += '------- Quantization Snapshot ------\n' debug_str += f'Num of Quant Config: [{len(quant_cfgs)}]\n' - for state, cnt in config_states_cnt.items(): - if cnt <= 0: continue - padding_str = ' ' * max(28 - len(state.name), 0) - debug_str += f'{state.name}:{padding_str} [{cnt}]\n' return debug_str def build_quant_pipeline( From 410991ee58021bddfb99de3d88d8258db9c9a960 Mon Sep 17 00:00:00 2001 From: tpoisonooo Date: Mon, 27 Jun 2022 20:48:01 +0800 Subject: [PATCH 16/22] fix(parser/onnx_exporter): remove mha fake output --- ppq/IR/morph.py | 2 +- ppq/parser/ncnn_exporter.py | 1 - ppq/parser/onnx_exporter.py | 11 ++++++++++- 3 files changed, 11 insertions(+), 3 deletions(-) diff --git a/ppq/IR/morph.py b/ppq/IR/morph.py index befec857..e244373f 100644 --- a/ppq/IR/morph.py +++ b/ppq/IR/morph.py @@ -401,7 +401,7 @@ def format_parameter_variables(self) -> None: # pop variable from graph self.graph.remove_variable(var) - + def format_mha(self) -> None: mha = [] for operation in self.graph.operations.values(): diff --git a/ppq/parser/ncnn_exporter.py b/ppq/parser/ncnn_exporter.py index 3661c1a2..4c937922 100644 --- a/ppq/parser/ncnn_exporter.py +++ b/ppq/parser/ncnn_exporter.py @@ -187,7 +187,6 @@ def export_quantization_config(self, config_path: str, graph: BaseGraph): else: self.export_raw_quant_config(config_path=config_path, graph=graph) - def export(self, file_path: str, graph: BaseGraph, config_path: str = None, input_shapes: List[List[int]] = [[1, 3, 224, 224]]): if config_path is not None: self.export_quantization_config(config_path, graph) diff --git a/ppq/parser/onnx_exporter.py b/ppq/parser/onnx_exporter.py index 4ce2af5a..149a70c6 100644 --- a/ppq/parser/onnx_exporter.py +++ b/ppq/parser/onnx_exporter.py @@ -1,3 +1,4 @@ +from ast import operator import json from typing import Union @@ -113,7 +114,7 @@ def export_operation(self, operation: Operation) -> onnx.OperatorProto: assert isinstance(exporter, OperationExporter), ( f'Expected an OpExporter here, however {type(exporter)} was given.') operation = exporter.export(operation=operation, graph=None) - + attributes = operation.attributes for key in attributes: value = attributes[key] @@ -165,7 +166,15 @@ def export_var(self, variable: Variable) -> onnx.TensorProto: dims=shape, vals=value) return tensor_proto + def remove_fake_node_output(self, graph: BaseGraph): + for opr in graph.topological_sort(): + if opr.type == 'MultiHeadAttention': + for var in opr.outputs[1:]: + graph.remove_variable(var) + def export(self, file_path: str, graph: BaseGraph, config_path: str = None): + self.remove_fake_node_output(graph) + # during export we will remove all boundary operations from graph. # we do not want to change the structure of original graph, # so there have to take a clone of it. From 114aa31b386a423c657065b167e4dad299deda89 Mon Sep 17 00:00:00 2001 From: tpoisonooo Date: Tue, 28 Jun 2022 19:09:04 +0800 Subject: [PATCH 17/22] fix(optim/training): support gemm.rank=3 --- ppq/api/setting.py | 3 ++- ppq/quantization/observer/range.py | 4 ++-- ppq/quantization/optim/training.py | 9 +++++++-- ppq/quantization/qfunction/linear.py | 8 +++++++- ppq/quantization/quantizer/NCNNQuantizer.py | 8 +++++--- 5 files changed, 23 insertions(+), 9 deletions(-) diff --git a/ppq/api/setting.py b/ppq/api/setting.py index ac488b9d..327af4c2 100644 --- a/ppq/api/setting.py +++ b/ppq/api/setting.py @@ -403,7 +403,7 @@ def __init__(self) -> None: self.advanced_optimization_setting = AdvancedOptimizationSetting() # 是否启动 bias correction pass - self.bias_correct = False + self.bias_correct = True self.bias_correct_setting = BiasCorrectionSetting() # 量化融合相关配置 @@ -443,6 +443,7 @@ def academic_setting() -> QuantizationSetting: @staticmethod def ncnn_setting() -> QuantizationSetting: default_setting = QuantizationSetting() + default_setting.bias_correct = True default_setting.fusion = False default_setting.dispatcher = 'pointwise' return default_setting diff --git a/ppq/quantization/observer/range.py b/ppq/quantization/observer/range.py index e799b085..bbb5e919 100644 --- a/ppq/quantization/observer/range.py +++ b/ppq/quantization/observer/range.py @@ -160,8 +160,8 @@ def render_quantization_config(self): # scale, offset here only deployed on cpu # we will move them towards target device through RunnableGraph - self._quant_cfg.scale = torch.tensor(scales, dtype=torch.float32, device=device) - self._quant_cfg.offset = torch.tensor(offsets, dtype=torch.float32, device=device) + self._quant_cfg.scale = torch.tensor(scales.clone().detach(), dtype=torch.float32, device=device) + self._quant_cfg.offset = torch.tensor(offsets.clone().detach(), dtype=torch.float32, device=device) self._quant_cfg.state = QuantizationStates.ACTIVATED else: raise TypeError('Min-max Observer only work with per-tensor or per-channel quantize policy.') diff --git a/ppq/quantization/optim/training.py b/ppq/quantization/optim/training.py index dd28f0ee..ad4d5eee 100644 --- a/ppq/quantization/optim/training.py +++ b/ppq/quantization/optim/training.py @@ -347,7 +347,7 @@ class BiasCorrectionPass(TrainingBasedPass): def __init__(self, auto_check: bool=False, interested_output: List[str] = None, verbose: bool = True, max_steps:int = 8) -> None: """Quantization can introduce a biased error in the activations. Bias - correction serves as a useful prosedure to eliminate those introduced + correction serves as a useful procedure to eliminate those introduced bias error. let: Y = WX + b @@ -379,7 +379,12 @@ def collect_bias(output: torch.Tensor, collector: list, op_type: str): if op_type in {'Conv', 'ConvTranspose'}: collector.append(torch.mean(output, dim=(0, 2, 3)).unsqueeze(0)) elif op_type in {'Gemm'}: - collector.append(torch.mean(output, dim=(0, )).unsqueeze(0)) + if len(output.shape) == 2: + collector.append(torch.mean(output, dim=(0, )).unsqueeze(0)) + elif len(output.shape) == 3: + collector.append(torch.mean(output, dim=(0,1)).unsqueeze(0)) + else: + raise ValueError(f'Unsupported Gemm shape: {output.shape}') else: raise TypeError(f'Unsupported Operation type: {op_type}') assert isinstance(executor, TorchExecutor), ( diff --git a/ppq/quantization/qfunction/linear.py b/ppq/quantization/qfunction/linear.py index 7ae596a2..aaede2b7 100644 --- a/ppq/quantization/qfunction/linear.py +++ b/ppq/quantization/qfunction/linear.py @@ -18,7 +18,6 @@ def __init__(self) -> None: def __call__(self, input_tensor: Any, quantization_config: TensorQuantizationConfig, **kwargs) -> Any: return super().__call__(input_tensor, quantization_config, **kwargs) - if not PPQ_CONFIG.USING_CUDA_KERNEL: class TensorwiseLinearQuantImpl(Function): """Torch Tensorwise quantize is designed to quantize a torch Tensor @@ -180,6 +179,13 @@ def PPQLinearQuantFunction( tensor, config.scale, config.offset, config.channel_axis, config.quant_min, config.quant_max, config.rounding, dropout) + elif config.policy.has_property(QuantizationProperty.PER_CHANNEL_BNC): + if config.channel_axis != 2: + raise ValueError('opr using BNC format while channel_axis != 2') + return ChannelwiseLinearQuantImpl.apply( + tensor, config.scale, config.offset, config.channel_axis, + config.quant_min, config.quant_max, config.rounding, + dropout) elif config.policy.has_property(QuantizationProperty.PER_TENSOR): return TensorwiseLinearQuantImpl.apply( tensor, config.scale, config.offset, diff --git a/ppq/quantization/quantizer/NCNNQuantizer.py b/ppq/quantization/quantizer/NCNNQuantizer.py index dc10523c..fe1a8cb0 100644 --- a/ppq/quantization/quantizer/NCNNQuantizer.py +++ b/ppq/quantization/quantizer/NCNNQuantizer.py @@ -94,7 +94,7 @@ def init_quantize_config( base_quant_config.input_quantization_config[0] = \ ChannelwiseTensorQuantizationConfig.convert_from_tensor_config( convert_from = inp_config, - offsets = None, scales = None, channel_axis = 1 + offsets = None, scales = None, channel_axis = 2 ) base_quant_config.input_quantization_config[0].observer_algorithm = 'Minmax' @@ -172,11 +172,13 @@ def default_platform(self) -> TargetPlatform: @ property def quant_operation_types(self) -> set: + # return { + # 'Add', 'Conv', 'LayerNorm', 'MultiHeadAttention', 'Gemm', 'Gelu' + # } return { - 'Add', 'Conv', 'LayerNorm', 'MultiHeadAttention', 'Gemm', 'Gelu' + 'Add', 'Conv', 'MultiHeadAttention', 'Gemm' } - @ property def quantize_policy(self) -> QuantizationPolicy: return QuantizationPolicy( From 165a1f470da34c3b45efd974eac3e436e05e6b66 Mon Sep 17 00:00:00 2001 From: tpoisonooo Date: Wed, 29 Jun 2022 13:47:53 +0800 Subject: [PATCH 18/22] improvement(ppq): clean code --- ppq/api/setting.py | 7 +++++-- ppq/executor/torch.py | 5 ----- ppq/parser/ncnn_exporter.py | 2 +- ppq/parser/util.py | 2 -- ppq/quantization/quantizer/base.py | 4 ++++ 5 files changed, 10 insertions(+), 10 deletions(-) diff --git a/ppq/api/setting.py b/ppq/api/setting.py index 327af4c2..22c52524 100644 --- a/ppq/api/setting.py +++ b/ppq/api/setting.py @@ -164,7 +164,7 @@ class ActivationQuantizationSetting(): def __init__(self) -> None: # 激活值校准算法,不区分大小写,可以选择 minmax, kl, percentile, MSE # activation calibration method - self.calib_algorithm = None + self.calib_algorithm = 'percentile' # 执行逐层激活值校准,延长执行时间,提升精度 # whether to calibrate activation per - layer. @@ -403,7 +403,7 @@ def __init__(self) -> None: self.advanced_optimization_setting = AdvancedOptimizationSetting() # 是否启动 bias correction pass - self.bias_correct = True + self.bias_correct = False self.bias_correct_setting = BiasCorrectionSetting() # 量化融合相关配置 @@ -446,6 +446,9 @@ def ncnn_setting() -> QuantizationSetting: default_setting.bias_correct = True default_setting.fusion = False default_setting.dispatcher = 'pointwise' + + default_setting.quantize_activation_setting.calib_algorithm = None + return default_setting @ staticmethod diff --git a/ppq/executor/torch.py b/ppq/executor/torch.py index fdab7b1b..eae53326 100644 --- a/ppq/executor/torch.py +++ b/ppq/executor/torch.py @@ -367,9 +367,6 @@ def __forward( outputs = outputs if isinstance(outputs, (list, tuple)) else [outputs] fp_outputs = outputs - # for out in fp_outputs: - # print("{} output shape {}".format(operation.name, out.shape)) - # quantize all result if is necessary if isinstance(operation, QuantableOperation): output_configs = [_ for _ in operation.config.output_quantization_config] @@ -394,8 +391,6 @@ def __forward( result_collector[output_names.index(output_var.name)] = outputs[output_idx] except Exception as e: - import pdb - pdb.set_trace() raise RuntimeError(f'Error happens when dealing with operation {str(operation)}, {str(e)}') # remove useless value(runtime clear). diff --git a/ppq/parser/ncnn_exporter.py b/ppq/parser/ncnn_exporter.py index 4c937922..1ba76f91 100644 --- a/ppq/parser/ncnn_exporter.py +++ b/ppq/parser/ncnn_exporter.py @@ -44,7 +44,7 @@ def dump_list(self, v): class NCNNExporter(GraphExporter): ''' raw format only support Conv and Gemm quantization ''' def export_raw_quant_config(self, config_path: str, graph: BaseGraph): - ''' ncnn table format when version <= 20220420 ''' + ''' ncnn table format when version <= 20220629 ''' fd = open(config_path, 'w+') topo_order = graph.topological_sort() for op in topo_order: diff --git a/ppq/parser/util.py b/ppq/parser/util.py index 232003a1..17c7c928 100644 --- a/ppq/parser/util.py +++ b/ppq/parser/util.py @@ -8,7 +8,6 @@ def convert_value( value: Union[int, float, np.ndarray, torch.Tensor], export_as_float: bool, dtype: DataType = DataType.FP32) -> Union[float, list]: - """Converting value from any to python native data dtype, ready for export. Args: @@ -19,7 +18,6 @@ def convert_value( Returns: Union[float, list]: Converted value """ - if dtype not in {DataType.FP32, DataType.INT32}: raise ValueError(f'Can Only export dtype fp32 and int32, ' f'while you are requiring to dump a {dtype.name} value') diff --git a/ppq/quantization/quantizer/base.py b/ppq/quantization/quantizer/base.py index 7ca6fe09..491f6b45 100644 --- a/ppq/quantization/quantizer/base.py +++ b/ppq/quantization/quantizer/base.py @@ -236,6 +236,10 @@ def report(self) -> str: debug_str += f'Num of Quantized Var: [{len(quant_vars)}]\n' debug_str += '------- Quantization Snapshot ------\n' debug_str += f'Num of Quant Config: [{len(quant_cfgs)}]\n' + for state, cnt in config_states_cnt.items(): + if cnt <= 0: continue + padding_str = ' ' * max(28 - len(state.name), 0) + debug_str += f'{state.name}:{padding_str} [{cnt}]\n' return debug_str def build_quant_pipeline( From 988980d20870288380905b2ed90f24404bf03524 Mon Sep 17 00:00:00 2001 From: tpoisonooo Date: Tue, 9 Aug 2022 18:26:33 +0800 Subject: [PATCH 19/22] feat(vit): support add and layernorm --- ppq/api/setting.py | 2 +- ppq/parser/ncnn_exporter.py | 27 +++++++++------------ ppq/quantization/observer/range.py | 20 +++++++-------- ppq/quantization/quantizer/NCNNQuantizer.py | 18 ++++++++++---- 4 files changed, 34 insertions(+), 33 deletions(-) diff --git a/ppq/api/setting.py b/ppq/api/setting.py index 22c52524..f8761726 100644 --- a/ppq/api/setting.py +++ b/ppq/api/setting.py @@ -443,7 +443,7 @@ def academic_setting() -> QuantizationSetting: @staticmethod def ncnn_setting() -> QuantizationSetting: default_setting = QuantizationSetting() - default_setting.bias_correct = True + # default_setting.bias_correct = True default_setting.fusion = False default_setting.dispatcher = 'pointwise' diff --git a/ppq/parser/ncnn_exporter.py b/ppq/parser/ncnn_exporter.py index 1ba76f91..a8862429 100644 --- a/ppq/parser/ncnn_exporter.py +++ b/ppq/parser/ncnn_exporter.py @@ -111,16 +111,8 @@ def export_ini_quant_config(self, config_path: str, graph: BaseGraph): elif op.type in {'Add'}: # Add may have multiple input node - input_scale = [] - zero_point = [] - - for cfg in op.config.input_quantization_config: - assert cfg.state in {QuantizationStates.BAKED, QuantizationStates.ACTIVATED, QuantizationStates.SLAVE} \ - and cfg.observer_algorithm in {'minmax', 'Minmax'} - input_scale.append(convert_value(1.0 / cfg.scale, True, DataType.FP32)) - zero_point.extend(convert_value(cfg.offset, False, DataType.INT32)) - - item['input_scale'] = input_scale + cfg_out = op.config.output_quantization_config[0] + item['output_scale'] = convert_value(1.0 / cfg_out.scale, True, DataType.FP32) elif op.type in {'Gelu'}: cfg = op.config.input_quantization_config[0] @@ -130,12 +122,15 @@ def export_ini_quant_config(self, config_path: str, graph: BaseGraph): item['input_scale'] = convert_value(1.0 / cfg.scale, True, DataType.FP32) elif op.type in {'LayerNorm'}: - cfg = op.config.input_quantization_config[0] - - assert cfg.state in {QuantizationStates.BAKED, QuantizationStates.ACTIVATED} \ - and cfg.observer_algorithm in {'minmax', 'Minmax'} - item['input_scale'] = convert_value(1.0 / cfg.scale, False, DataType.FP32) - item['zero_point'] = convert_value(cfg.offset, True, DataType.INT32) + cfg_in = op.config.input_quantization_config[0] + cfg_out = op.config.output_quantization_config[0] + + assert cfg_in.state in {QuantizationStates.BAKED, QuantizationStates.ACTIVATED} \ + and cfg_in.observer_algorithm in {'minmax', 'Minmax'} \ + and cfg_out.state in {QuantizationStates.BAKED, QuantizationStates.ACTIVATED} + + item['input_scales'] = convert_value(1.0 / cfg_in.scale, False, DataType.FP32) + item['output_scale'] = convert_value(1.0 / cfg_out.scale, True, DataType.FP32) elif op.type == 'MultiHeadAttention': # write input scale diff --git a/ppq/quantization/observer/range.py b/ppq/quantization/observer/range.py index bbb5e919..a44c1a9c 100644 --- a/ppq/quantization/observer/range.py +++ b/ppq/quantization/observer/range.py @@ -325,18 +325,16 @@ def observe(self, value: torch.Tensor): self._percentile_collector.append(CUDA.Quantile(value, self._percentile).view(1, -1)) elif self._quant_cfg.policy.has_property(QuantizationProperty.PER_CHANNEL) or \ self._quant_cfg.policy.has_property(QuantizationProperty.PER_CHANNEL_BNC): - import pdb - pdb.set_trace() - raise PermissionError('Percentile observer can not deal with per channel quantization.') + # raise PermissionError('Percentile observer can not deal with per channel quantization.') - # assert isinstance(self._quant_cfg, ChannelwiseTensorQuantizationConfig), ( - # 'Your quantization config has PER_CHAN=NEL while it is not a ' - # 'ChannelwiseTensorQuantizationConfig instance.') - # channel_axis = self._quant_cfg.channel_axis - # channelwise_view = value.transpose(dim0=0, dim1=channel_axis) - # channelwise_view = torch.flatten(channelwise_view, start_dim=1) - # self._percentile_mins.append(-torch.quantile(-channelwise_view, q=self._percentile, dim=1, keepdim=True)[0]) - # self._percentile_maxs.append(torch.quantile(channelwise_view, q=self._percentile, dim=1, keepdimTrue)[0]) + assert isinstance(self._quant_cfg, ChannelwiseTensorQuantizationConfig), ( + 'Your quantization config has PER_CHAN=NEL while it is not a ' + 'ChannelwiseTensorQuantizationConfig instance.') + channel_axis = self._quant_cfg.channel_axis + channelwise_view = value.transpose(dim0=0, dim1=channel_axis) + channelwise_view = torch.flatten(channelwise_view, start_dim=1) + self._percentile_mins.append(-torch.quantile(-channelwise_view, q=self._percentile, dim=1, keepdim=True)[0]) + self._percentile_maxs.append(torch.quantile(channelwise_view, q=self._percentile, dim=1, keepdim=True)[0]) else: raise TypeError('Min-max Observer only work with per-tensor or per-channel quantize policy.') diff --git a/ppq/quantization/quantizer/NCNNQuantizer.py b/ppq/quantization/quantizer/NCNNQuantizer.py index fe1a8cb0..ab50ca31 100644 --- a/ppq/quantization/quantizer/NCNNQuantizer.py +++ b/ppq/quantization/quantizer/NCNNQuantizer.py @@ -104,8 +104,18 @@ def init_quantize_config( wconfig.state = QuantizationStates.FP32 bconfig.state = QuantizationStates.FP32 + # 输出量化 + output_policy = QuantizationPolicy( + QuantizationProperty.SYMMETRICAL + + QuantizationProperty.LINEAR + + QuantizationProperty.PER_TENSOR + ) + base_quant_config.output_quantization_config[0].policy = output_policy + base_quant_config.output_quantization_config[0].observer_algorithm = 'Minmax' + elif operation.type in {'Add'}: # use default param + pass elif operation.type == 'MultiHeadAttention': @@ -158,7 +168,7 @@ def init_quantize_config( # 显式说明输出不量化 - if operation.type != 'MultiHeadAttention': + if operation.type not in {'MultiHeadAttention', 'LayerNorm', 'Add'}: base_quant_config.output_quantization_config[0].state = QuantizationStates.FP32 return base_quant_config @@ -172,11 +182,9 @@ def default_platform(self) -> TargetPlatform: @ property def quant_operation_types(self) -> set: - # return { - # 'Add', 'Conv', 'LayerNorm', 'MultiHeadAttention', 'Gemm', 'Gelu' - # } return { - 'Add', 'Conv', 'MultiHeadAttention', 'Gemm' + 'Add', 'Conv', 'LayerNorm', 'MultiHeadAttention', 'Gemm' + # 'Add', 'Conv', 'LayerNorm', 'MultiHeadAttention', 'Gemm' } @ property From 2a55df368fa3829affc55d2713c66bc57f446cb5 Mon Sep 17 00:00:00 2001 From: tpoisonooo Date: Wed, 10 Aug 2022 20:02:01 +0800 Subject: [PATCH 20/22] feat(NCNNQuantizer): update binaryop config --- ppq/parser/ncnn_exporter.py | 8 ++++++++ ppq/quantization/quantizer/NCNNQuantizer.py | 1 - 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/ppq/parser/ncnn_exporter.py b/ppq/parser/ncnn_exporter.py index a8862429..281bee5d 100644 --- a/ppq/parser/ncnn_exporter.py +++ b/ppq/parser/ncnn_exporter.py @@ -1,3 +1,4 @@ +from lib2to3.pytree import convert from typing import List from ppq.core import (DataType, NetworkFramework, QuantizationProperty, @@ -111,6 +112,13 @@ def export_ini_quant_config(self, config_path: str, graph: BaseGraph): elif op.type in {'Add'}: # Add may have multiple input node + input_scales = [] + for cfg_in in op.config.input_quantization_config: + assert cfg_in.state in {QuantizationStates.BAKED, QuantizationStates.ACTIVATED} \ + and cfg_in.observer_algorithm in {'minmax', 'Minmax'} + input_scales.append(convert_value(1.0 / cfg_in.scale, True, DataType.FP32)) + item['input_scales'] = input_scales + cfg_out = op.config.output_quantization_config[0] item['output_scale'] = convert_value(1.0 / cfg_out.scale, True, DataType.FP32) diff --git a/ppq/quantization/quantizer/NCNNQuantizer.py b/ppq/quantization/quantizer/NCNNQuantizer.py index ab50ca31..679c9269 100644 --- a/ppq/quantization/quantizer/NCNNQuantizer.py +++ b/ppq/quantization/quantizer/NCNNQuantizer.py @@ -115,7 +115,6 @@ def init_quantize_config( elif operation.type in {'Add'}: # use default param - pass elif operation.type == 'MultiHeadAttention': From 76713e4da3b40f9ec3eadeaa7e54639f5eb5edf7 Mon Sep 17 00:00:00 2001 From: tpoisonooo Date: Thu, 18 Aug 2022 11:51:15 +0800 Subject: [PATCH 21/22] update --- ppq/api/setting.py | 2 +- ppq/parser/ncnn_exporter.py | 22 ++++++++++----------- ppq/quantization/quantizer/NCNNQuantizer.py | 6 +----- 3 files changed, 13 insertions(+), 17 deletions(-) diff --git a/ppq/api/setting.py b/ppq/api/setting.py index f8761726..22c52524 100644 --- a/ppq/api/setting.py +++ b/ppq/api/setting.py @@ -443,7 +443,7 @@ def academic_setting() -> QuantizationSetting: @staticmethod def ncnn_setting() -> QuantizationSetting: default_setting = QuantizationSetting() - # default_setting.bias_correct = True + default_setting.bias_correct = True default_setting.fusion = False default_setting.dispatcher = 'pointwise' diff --git a/ppq/parser/ncnn_exporter.py b/ppq/parser/ncnn_exporter.py index 281bee5d..582c7441 100644 --- a/ppq/parser/ncnn_exporter.py +++ b/ppq/parser/ncnn_exporter.py @@ -110,17 +110,17 @@ def export_ini_quant_config(self, config_path: str, graph: BaseGraph): scale = param_cfg.scale item['weight'] = convert_value(1 / scale, False, DataType.FP32) - elif op.type in {'Add'}: - # Add may have multiple input node - input_scales = [] - for cfg_in in op.config.input_quantization_config: - assert cfg_in.state in {QuantizationStates.BAKED, QuantizationStates.ACTIVATED} \ - and cfg_in.observer_algorithm in {'minmax', 'Minmax'} - input_scales.append(convert_value(1.0 / cfg_in.scale, True, DataType.FP32)) - item['input_scales'] = input_scales - - cfg_out = op.config.output_quantization_config[0] - item['output_scale'] = convert_value(1.0 / cfg_out.scale, True, DataType.FP32) + # elif op.type in {'Add'}: + # # Add may have multiple input node + # input_scales = [] + # for cfg_in in op.config.input_quantization_config: + # assert cfg_in.state in {QuantizationStates.BAKED, QuantizationStates.ACTIVATED} \ + # and cfg_in.observer_algorithm in {'minmax', 'Minmax'} + # input_scales.append(convert_value(1.0 / cfg_in.scale, True, DataType.FP32)) + # item['input_scales'] = input_scales + + # cfg_out = op.config.output_quantization_config[0] + # item['output_scale'] = convert_value(1.0 / cfg_out.scale, True, DataType.FP32) elif op.type in {'Gelu'}: cfg = op.config.input_quantization_config[0] diff --git a/ppq/quantization/quantizer/NCNNQuantizer.py b/ppq/quantization/quantizer/NCNNQuantizer.py index 679c9269..9f9fba35 100644 --- a/ppq/quantization/quantizer/NCNNQuantizer.py +++ b/ppq/quantization/quantizer/NCNNQuantizer.py @@ -113,10 +113,6 @@ def init_quantize_config( base_quant_config.output_quantization_config[0].policy = output_policy base_quant_config.output_quantization_config[0].observer_algorithm = 'Minmax' - elif operation.type in {'Add'}: - # use default param - pass - elif operation.type == 'MultiHeadAttention': # setup input quant param input_policy = QuantizationPolicy( @@ -182,7 +178,7 @@ def default_platform(self) -> TargetPlatform: @ property def quant_operation_types(self) -> set: return { - 'Add', 'Conv', 'LayerNorm', 'MultiHeadAttention', 'Gemm' + 'Conv', 'LayerNorm', 'MultiHeadAttention', 'Gemm' # 'Add', 'Conv', 'LayerNorm', 'MultiHeadAttention', 'Gemm' } From 314273b17be0545d6bf18b2c924c0b390e56d1e2 Mon Sep 17 00:00:00 2001 From: tpoisonooo Date: Fri, 19 Aug 2022 17:51:06 +0800 Subject: [PATCH 22/22] update --- ppq/parser/ncnn_exporter.py | 1 - ppq/quantization/observer/range.py | 6 +++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/ppq/parser/ncnn_exporter.py b/ppq/parser/ncnn_exporter.py index 582c7441..ba4aac64 100644 --- a/ppq/parser/ncnn_exporter.py +++ b/ppq/parser/ncnn_exporter.py @@ -136,7 +136,6 @@ def export_ini_quant_config(self, config_path: str, graph: BaseGraph): assert cfg_in.state in {QuantizationStates.BAKED, QuantizationStates.ACTIVATED} \ and cfg_in.observer_algorithm in {'minmax', 'Minmax'} \ and cfg_out.state in {QuantizationStates.BAKED, QuantizationStates.ACTIVATED} - item['input_scales'] = convert_value(1.0 / cfg_in.scale, False, DataType.FP32) item['output_scale'] = convert_value(1.0 / cfg_out.scale, True, DataType.FP32) diff --git a/ppq/quantization/observer/range.py b/ppq/quantization/observer/range.py index a44c1a9c..73deff82 100644 --- a/ppq/quantization/observer/range.py +++ b/ppq/quantization/observer/range.py @@ -355,9 +355,9 @@ def render_quantization_config(self): self._quant_cfg.state = QuantizationStates.ACTIVATED elif self._quant_cfg.policy.has_property(QuantizationProperty.PER_CHANNEL) or \ self._quant_cfg.policy.has_property(QuantizationProperty.PER_CHANNEL_BNC): - import pdb - pdb.set_trace() - raise PermissionError('Percentile observer can not deal with per channel quantization.') + # import pdb + # pdb.set_trace() + # raise PermissionError('Percentile observer can not deal with per channel quantization.') if len(self._percentile_maxs) == 0: raise ValueError('Can not render quantization config yet, Observer data collator is empty. ' \ 'Invoke observe() function before render config.')