diff --git a/doc/pages/instructions/ppq_quant_1.html b/doc/pages/instructions/ppq_quant_1.html
index f674d985..486f8126 100644
--- a/doc/pages/instructions/ppq_quant_1.html
+++ b/doc/pages/instructions/ppq_quant_1.html
@@ -183,7 +183,8 @@
QuantizationPolicy
QuantizationPolicy 在 PPQ 中用来描述量化策略,它是一些 QuantizationProperty 枚举的组合位图。在 PPQ 中我们支持的 QuantizationProperty 包括:
- PER_TENSOR:逐层量化。
- - PER_CHANNEL:逐通道量化。
+ - PER_CHANNEL:CNN 模型逐通道量化。
+ - PER_CHANNEL_BNC:Transformer 模型逐通道量化。
- LINEAR: 线性量化。
- EXPONENTIAL: 指数量化。
- SYMMETRICAL: 对称量化。
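
A minimal sketch of how the property flags above combine into a policy bitmap, assuming QuantizationPolicy and QuantizationProperty are re-exported from ppq.core (they are defined in ppq/core/quant.py further down in this patch); which combinations a backend actually uses is decided by its quantizer.

    from ppq.core import QuantizationPolicy, QuantizationProperty

    # per-tensor symmetric linear quantization, the usual activation policy
    act_policy = QuantizationPolicy(
        QuantizationProperty.SYMMETRICAL +
        QuantizationProperty.LINEAR +
        QuantizationProperty.PER_TENSOR)

    # the new transformer-style policy: per-channel over the last (BNC) axis
    # with power-of-two-factor scales, used for LayerNorm inputs below
    ln_policy = QuantizationPolicy(
        QuantizationProperty.SYMMETRICAL +
        QuantizationProperty.LINEAR +
        QuantizationProperty.PER_CHANNEL_BNC +
        QuantizationProperty.PTF_BNC)

    assert ln_policy.has_property(QuantizationProperty.PER_CHANNEL_BNC)
    assert not ln_policy.has_property(QuantizationProperty.PER_CHANNEL)
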
diff --git a/ppq/IR/base/command.py b/ppq/IR/base/command.py
index 42a9ad2e..1627ba59 100644
--- a/ppq/IR/base/command.py
+++ b/ppq/IR/base/command.py
@@ -96,6 +96,8 @@ class GraphCommandType(Enum):
FORMAT_SLICE = 29
# 从一个指定位置将图截断
TRUNCATE_ON_VAR = 30
+ # 输出 MultiHeadAttention 中间过程,从而让 Quantizer 配置是否量化
+ FORMAT_MHA = 31
class GraphCommand():
def __init__(self, command_type: GraphCommandType, **kwargs) -> None:
diff --git a/ppq/IR/deploy.py b/ppq/IR/deploy.py
index 2316a292..7dbd8a05 100644
--- a/ppq/IR/deploy.py
+++ b/ppq/IR/deploy.py
@@ -68,7 +68,7 @@ def retrieve(self):
for _, variable in self._graph.variables.items():
assert isinstance(variable, Variable), \
f'Failed to send graph to device, incorrect variable {variable} found.'
- variable.value = convert_any_to_numpy(variable.value, accepet_none=True)
+ variable.value = convert_any_to_numpy(variable.value, accept_none=True)
return self
@@ -84,12 +84,12 @@ def deploy(self, device: str):
if operator.type == 'Constant' and operator.platform != TargetPlatform.SHAPE_OR_INDEX:
operator.attributes['value'] = \
convert_any_to_torch_tensor(
- operator.attributes['value'], accepet_none=False).to(device)
+ operator.attributes['value'], accept_none=False).to(device)
if operator.type == 'Constant' and operator.platform == TargetPlatform.SHAPE_OR_INDEX:
value = operator.attributes['value']
operator.attributes['value'] = convert_any_to_torch_tensor(
- value, accepet_none=False, device='cpu')
+ value, accept_none=False, device='cpu')
for _, variable in self._graph.variables.items():
assert isinstance(variable, Variable), \
@@ -108,10 +108,10 @@ def deploy(self, device: str):
# if all downstream operations are shape related operations, send value to cpu
if platform == TargetPlatform.SHAPE_OR_INDEX:
variable.value = convert_any_to_torch_tensor(
- variable.value, accepet_none=True).to('cpu')
+ variable.value, accept_none=True).to('cpu')
else:
variable.value = convert_any_to_torch_tensor(
- variable.value, accepet_none=True).to(device=device)
+ variable.value, accept_none=True).to(device=device)
# if variable is a shape-related variable, send it to cpu.
if variable.is_parameter:
@@ -125,5 +125,5 @@ def deploy(self, device: str):
if dest_op.type in {'Reshape', 'Slice', 'Gather', 'Pad', 'Resize', 'Split', 'TopK', 'Tile', 'Expand'}:
if dest_idx >= 1 and len(variable.dest_ops) == 1:
variable.value = convert_any_to_torch_tensor(
- variable.value, accepet_none=True).to('cpu')
+ variable.value, accept_none=True).to('cpu')
return self
diff --git a/ppq/IR/morph.py b/ppq/IR/morph.py
index ff5a2dca..e244373f 100644
--- a/ppq/IR/morph.py
+++ b/ppq/IR/morph.py
@@ -90,7 +90,8 @@ def _acceptable_command_types(self) -> List[GraphCommandType]:
GraphCommandType.FORMAT_PARAMETERS,
GraphCommandType.FORMAT_CONSTANT_INPUT,
GraphCommandType.FORMAT_SLICE,
- GraphCommandType.TRUNCATE_ON_VAR
+ GraphCommandType.TRUNCATE_ON_VAR,
+ GraphCommandType.FORMAT_MHA,
]
def process(self, command: GraphCommand) -> Any:
@@ -107,7 +108,7 @@ def process(self, command: GraphCommand) -> Any:
if command.command_type == GraphCommandType.FORMAT_INT64_CONSTANT:
return self.format_int64_constant()
if command.command_type == GraphCommandType.REPLACE_SUB:
- return self.replace_substarction()
+ return self.replace_substraction()
if command.command_type == GraphCommandType.FORMAT_PARAMETERS:
return self.format_parameter_variables()
if command.command_type == GraphCommandType.FORMAT_CONSTANT_INPUT:
@@ -117,6 +118,8 @@ def process(self, command: GraphCommand) -> Any:
if command.command_type == GraphCommandType.TRUNCATE_ON_VAR:
assert isinstance(command, TruncateGraphCommand), f'Use TruncateGraphCommand here.'
return self.truncate_on_var(command.var, command.mark_as_output)
+ if command.command_type == GraphCommandType.FORMAT_MHA:
+ return self.format_mha()
def format_slice(self) -> None:
"""
@@ -398,8 +401,34 @@ def format_parameter_variables(self) -> None:
# pop variable from graph
self.graph.remove_variable(var)
-
- def replace_substarction(self) -> None:
+
+ def format_mha(self) -> None:
+ mha = []
+ for operation in self.graph.operations.values():
+ if operation.type == 'MultiHeadAttention':
+ mha.append(operation)
+
+ for opr in mha:
+ assert isinstance(opr, Operation)
+ q_var = Variable(name=opr.name + '_fake_q_', source_op=opr)
+ k_var = Variable(name=opr.name + '_fake_k_', source_op=opr)
+ v_var = Variable(name=opr.name + '_fake_v_', source_op=opr)
+ energy_var = Variable(name=opr.name + '_fake_energy_', source_op=opr)
+ feat_var = Variable(name=opr.name + '_fake_feat_', source_op=opr)
+
+ opr.outputs.append(q_var)
+ opr.outputs.append(k_var)
+ opr.outputs.append(v_var)
+ opr.outputs.append(energy_var)
+ opr.outputs.append(feat_var)
+
+ self.graph.append_variable(q_var)
+ self.graph.append_variable(k_var)
+ self.graph.append_variable(v_var)
+ self.graph.append_variable(energy_var)
+ self.graph.append_variable(feat_var)
+
+ def replace_substraction(self) -> None:
substractions = []
for operation in self.graph.operations.values():
if operation.type == 'Sub':
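
A hedged sketch of driving the new FORMAT_MHA pass. The class that owns format_mha is not named in this hunk; GraphFormatter from ppq/IR/morph.py is assumed here, and each MultiHeadAttention op is assumed to start with a single real output.

    from ppq.IR import GraphFormatter                      # assumed export path
    from ppq.IR.base.command import GraphCommand, GraphCommandType

    def attach_mha_internal_outputs(graph):
        # append the fake q / k / v / energy / feat outputs to every MultiHeadAttention op
        formatter = GraphFormatter(graph)
        formatter(GraphCommand(GraphCommandType.FORMAT_MHA))
        for op in graph.operations.values():
            if op.type == 'MultiHeadAttention':
                # one real output plus five fake internal outputs,
                # matching what MultiHeadAttention_forward returns
                assert len(op.outputs) == 6
        return graph
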
diff --git a/ppq/api/interface.py b/ppq/api/interface.py
index 3570b5c5..4ce638fb 100644
--- a/ppq/api/interface.py
+++ b/ppq/api/interface.py
@@ -617,6 +617,7 @@ def format_graph(graph: BaseGraph) -> BaseGraph:
在 PPQ 中,我们不希望出现 Constant 算子,所有 Constant 输入将被当作 parameter variable 连接到下游算子上
在 PPQ 中,我们不希望出现 Batchnorm 算子,所有 Batchnorm 将被合并
在 PPQ 中,我们不希望出现权重共享的算子,所有被共享的权重将被复制分裂成多份
+ 在 PPQ 中,我们希望 MultiHeadAttention 算子中间过程可被量化
在 PPQ 中,我们不希望出现孤立算子,所有孤立算子将被移除
This function takes pre-processing procedure with your graph.
@@ -639,6 +640,7 @@ def format_graph(graph: BaseGraph) -> BaseGraph:
formatter(GraphCommand(GraphCommandType.FORMAT_CAST))
formatter(GraphCommand(GraphCommandType.FORMAT_SLICE))
formatter(GraphCommand(GraphCommandType.FORMAT_CLIP))
+ formatter(GraphCommand(GraphCommandType.FORMAT_MHA))
formatter(GraphCommand(GraphCommandType.DELETE_ISOLATED))
return graph
@@ -657,6 +659,7 @@ def dispatch_graph(graph: BaseGraph, platform: TargetPlatform, setting: Quantiza
"""
assert platform in QUANTIZER_COLLECTION, (
f'Platform misunderstood, except one of following platform {QUANTIZER_COLLECTION.keys()}')
+
quantizer = QUANTIZER_COLLECTION[platform](graph) # 初始化一个 quantizer 没有很大代价...
if str(setting.dispatcher).lower() not in DISPATCHER_TABLE:
diff --git a/ppq/api/setting.py b/ppq/api/setting.py
index dbb393a4..22c52524 100644
--- a/ppq/api/setting.py
+++ b/ppq/api/setting.py
@@ -443,8 +443,12 @@ def academic_setting() -> QuantizationSetting:
@staticmethod
def ncnn_setting() -> QuantizationSetting:
default_setting = QuantizationSetting()
+ default_setting.bias_correct = True
default_setting.fusion = False
default_setting.dispatcher = 'pointwise'
+
+ default_setting.quantize_activation_setting.calib_algorithm = None
+
return default_setting
@ staticmethod
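
A short sketch of what the updated preset now produces; the factory class name QuantizationSettingFactory is assumed (only its staticmethods appear in this hunk), and the attributes checked are exactly the ones set above.

    from ppq.api.setting import QuantizationSetting, QuantizationSettingFactory  # class name assumed

    setting = QuantizationSettingFactory.ncnn_setting()
    assert isinstance(setting, QuantizationSetting)
    assert setting.bias_correct and not setting.fusion          # bias correction on, fusion off
    assert setting.dispatcher == 'pointwise'
    assert setting.quantize_activation_setting.calib_algorithm is None
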
diff --git a/ppq/core/common.py b/ppq/core/common.py
index 5444a0bb..5ee73d96 100644
--- a/ppq/core/common.py
+++ b/ppq/core/common.py
@@ -29,7 +29,7 @@
LINEAR_ACTIVATIONS = {'Relu', 'Clip'}
# COPUTING OP 是所有计算层,该属性被用于联合定点和子图切分
-COMPUTING_OP = {'Conv', 'Gemm', 'ConvTranspose', 'MatMul'}
+COMPUTING_OP = {'Conv', 'Gemm', 'ConvTranspose', 'MatMul', 'LayerNorm', 'MultiHeadAttention'}
# SOI OP 是所有产生 SOI 输出的节点类型,该属性被用于子图切分
SOI_OP = {'TopK', 'Shape', 'NonMaxSuppression'}
# 强制联合定点的算子种类
diff --git a/ppq/core/data.py b/ppq/core/data.py
index ff4dee43..d171d366 100644
--- a/ppq/core/data.py
+++ b/ppq/core/data.py
@@ -212,20 +212,20 @@ def num_of_output(self):
def convert_any_to_python_primary_type(
x: Union[torch.Tensor, np.ndarray, int, float, list, str],
- accepet_none: bool=True) -> Union[int, float, list, str]:
- if x is None and accepet_none: return None
- if x is None and not accepet_none: raise ValueError('Trying to convert an empty value.')
+ accept_none: bool=True) -> Union[int, float, list, str]:
+ if x is None and accept_none: return None
+ if x is None and not accept_none: raise ValueError('Trying to convert an empty value.')
if isinstance(x, list) or isinstance(x, tuple): return list(x)
elif isinstance(x, int) or isinstance(x, float): return x
elif isinstance(x, torch.Tensor):
- if x.numel() == 0 and accepet_none: return None
- if x.numel() == 0 and not accepet_none: raise ValueError('Trying to convert an empty value.')
+ if x.numel() == 0 and accept_none: return None
+ if x.numel() == 0 and not accept_none: raise ValueError('Trying to convert an empty value.')
if str(x.device) != 'cpu': x = x.cpu()
if x.numel() == 1: return x.item()
if x.numel() > 1: return x.tolist()
elif isinstance(x, np.ndarray):
- if x.size == 0 and accepet_none: return None
- if x.size == 0 and not accepet_none: raise ValueError('Trying to convert an empty value.')
+ if x.size == 0 and accept_none: return None
+ if x.size == 0 and not accept_none: raise ValueError('Trying to convert an empty value.')
if x.size == 1: return x.reshape((1, )).tolist()[0]
if x.size > 1: return x.tolist()
elif isinstance(x, str):
@@ -236,14 +236,14 @@ def convert_any_to_python_primary_type(
def convert_any_to_numpy(
x: Union[torch.Tensor, np.ndarray, int, float, list, tuple],
- accepet_none: bool=True) -> np.ndarray:
- if x is None and accepet_none: return None
- if x is None and not accepet_none: raise ValueError('Trying to convert an empty value.')
+ accept_none: bool=True) -> np.ndarray:
+ if x is None and accept_none: return None
+ if x is None and not accept_none: raise ValueError('Trying to convert an empty value.')
if isinstance(x, np.ndarray): return x
elif isinstance(x, int) or isinstance(x, float): return np.array([x, ])
elif isinstance(x, torch.Tensor):
- if x.numel() == 0 and accepet_none: return None
- if x.numel() == 0 and not accepet_none: raise ValueError('Trying to convert an empty value.')
+ if x.numel() == 0 and accept_none: return None
+ if x.numel() == 0 and not accept_none: raise ValueError('Trying to convert an empty value.')
if x.numel() == 1: return convert_any_to_numpy(x.detach().cpu().item())
if x.numel() > 1: return x.detach().cpu().numpy()
elif isinstance(x, list) or isinstance(x, tuple):
@@ -254,9 +254,9 @@ def convert_any_to_numpy(
def convert_any_to_torch_tensor(
x: Union[torch.Tensor, np.ndarray, int, float, list, tuple],
- accepet_none: bool=True, dtype: torch.dtype=None, device='cpu') -> torch.Tensor:
- if x is None and accepet_none: return None
- if x is None and not accepet_none: raise ValueError('Trying to convert an empty value.')
+ accept_none: bool=True, dtype: torch.dtype=None, device='cpu') -> torch.Tensor:
+ if x is None and accept_none: return None
+ if x is None and not accept_none: raise ValueError('Trying to convert an empty value.')
if isinstance(x, list) or isinstance(x, tuple):
if all([type(element) == int for element in x]):
if dtype is None: dtype=torch.int64
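
The accepet_none -> accept_none rename is mechanical; a short sketch of the corrected keyword in use (the array contents are arbitrary illustration values).

    import numpy as np
    from ppq.core.data import convert_any_to_numpy

    assert isinstance(convert_any_to_numpy([1.0, 2.0, 3.0], accept_none=False), np.ndarray)

    # accept_none=True silently maps an empty value to None ...
    assert convert_any_to_numpy(None, accept_none=True) is None
    # ... while accept_none=False raises ValueError instead
    try:
        convert_any_to_numpy(None, accept_none=False)
    except ValueError:
        pass
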
diff --git a/ppq/core/quant.py b/ppq/core/quant.py
index a9681aed..56f555c5 100644
--- a/ppq/core/quant.py
+++ b/ppq/core/quant.py
@@ -140,7 +140,9 @@ class QuantizationProperty(Enum):
PER_TENSOR: Also known as per-layer quantization, which mean all parameters of this layer share the same scale and offset.
(For Convulution layer and Gemm layer which has bias, bias layer will be negative quantized, they do not have a valid scale)
- PER_CHANNEL: parameters are quantized alone channel axis, each channel has a stand-alone scale and offset.
+ PER_CHANNEL: CNN model parameters are quantized along channel axis, each channel has a stand-alone scale and offset.
+
+    PER_CHANNEL_BNC: Transformer model tensors in BNC layout are quantized per-channel along the last axis, each channel has a stand-alone scale and offset.
LINEAR: Indicates a linear quantization, follow formula: quant(x) = clip(round(x / scale))
@@ -151,18 +153,22 @@ class QuantizationProperty(Enum):
ASYMMETRICAL: Indicates an asymmetrical quantization, offset is activated in this mode.
POWER_OF_2: Indicates a power-of-2 quantization, scale must be pow(2, k) in this mode.
+
+    PTF_BNC: Indicates power-of-two-factor (PTF) quantization for LayerNorm input; in this mode each channel's scale is a shared base scale times pow(2, k).
ATTENTION: Not all combinations of all 7 QuantizationProperty are valid, see QuantizationPolicy.__check_valid
ATTENTION: QuantizationPolicy is read-only, user can only assign its value when created, the only interface of
QuantizationPolicy is function QuantizationPolicy.has_property.
"""
- PER_TENSOR = 0x00000001
- PER_CHANNEL = 0x00000002
- LINEAR = 0x00000004
- EXPONENTIAL = 0x00000008
- SYMMETRICAL = 0x00000010
- ASYMMETRICAL = 0x00000020
- POWER_OF_2 = 0x00000040
+ PER_TENSOR = 0x00000001
+ PER_CHANNEL = 0x00000002
+ PER_CHANNEL_BNC = 0x00000080
+ LINEAR = 0x00000004
+ EXPONENTIAL = 0x00000008
+ SYMMETRICAL = 0x00000010
+ ASYMMETRICAL = 0x00000020
+ POWER_OF_2 = 0x00000040
+ PTF_BNC = 0x00000100
def __or__(self, other: int) -> int:
return self.value + other
@@ -247,6 +253,7 @@ def __check_valid(cls, policy):
QuantizationProperty.SYMMETRICAL | QuantizationProperty.LINEAR | QuantizationProperty.PER_TENSOR | QuantizationProperty.POWER_OF_2,
QuantizationProperty.ASYMMETRICAL | QuantizationProperty.LINEAR | QuantizationProperty.PER_CHANNEL | QuantizationProperty.POWER_OF_2,
QuantizationProperty.SYMMETRICAL | QuantizationProperty.LINEAR | QuantizationProperty.PER_CHANNEL | QuantizationProperty.POWER_OF_2,
+ QuantizationProperty.SYMMETRICAL | QuantizationProperty.LINEAR | QuantizationProperty.PER_CHANNEL_BNC | QuantizationProperty.PTF_BNC,
}
def to_dict(self) -> dict:
@@ -658,7 +665,7 @@ class ChannelwiseTensorQuantizationConfig(TensorQuantizationConfig):
"""ChannelwiseTensorQuantizationConfig is a special case for tensor
quantization configuration.
- Comparing with per-tensor quantization configuration, pre-channel quantization has a property
+ Comparing with per-tensor quantization configuration, per-channel quantization has a property
"channel_axis" to indicate a channel axis where quantization takes effects.
Along this axis, all tensor values will be quantized with a sharing scale and offset,
diff --git a/ppq/csrc/cuda/PPQ.h b/ppq/csrc/cuda/PPQ.h
index d5138fb6..d1d5fed8 100644
--- a/ppq/csrc/cuda/PPQ.h
+++ b/ppq/csrc/cuda/PPQ.h
@@ -112,13 +112,16 @@
ATTENTION: QuantizationPolicy is read-only, user can only assign its value when created, the only interface of
QuantizationPolicy is function QuantizationPolicy.has_property.
*/
- # define PPQ_QPROPERTY_PER_TENSOR 0x00000001
- # define PPQ_QPROPERTY_PER_CHANNEL 0x00000002
+ # define PPQ_QPROPERTY_PER_TENSOR 0x00000001
+ # define PPQ_QPROPERTY_PER_CHANNEL 0x00000002
+ # define PPQ_QPROPERTY_PER_CHANNEL_BNC 0x00000080
+
# define PPQ_QPROPERTY_LINEAR 0x00000004
# define PPQ_QPROPERTY_EXPONENTIAL 0x00000008
# define PPQ_QPROPERTY_SYMMETRICAL 0x00000010
# define PPQ_QPROPERTY_ASYMMETRICAL 0x00000020
# define PPQ_QPROPERTY_POWER_OF_2 0x00000040
+ # define PPQ_QPROPERTY_PTF_BNC 0x00000100
# define PPQ_CPP_EXTENSION
# endif
diff --git a/ppq/executor/op/torch/default.py b/ppq/executor/op/torch/default.py
index e97099a0..fe846350 100644
--- a/ppq/executor/op/torch/default.py
+++ b/ppq/executor/op/torch/default.py
@@ -333,6 +333,55 @@ def Mul_forward(op: Operation, values: List[torch.Tensor], ctx: TorchBackendCont
return multiplicand * multiplier
+def MultiHeadAttention_forward(op: Operation, values: List[torch.Tensor], ctx: TorchBackendContext = None, **kwargs) -> list:
+ """Perform MultiHeadAttetion opr forward.
+
+ Args:
+ op (Operation): MultiHeadAttention
+ values (List[torch.Tensor]): opr inputs
+ ctx (TorchBackendContext, optional): Context. Defaults to None.
+
+ Raises:
+        NotImplementedError: in the [ViT paper](https://arxiv.org/abs/2010.11929) the q/k/v inputs of MultiHeadAttention are the same tensor; we assume the op has **not** been simplified and still carries all 11 inputs.
+        ValueError: raised when the `embed_dim` or `num_heads` attribute is missing.
+
+ Returns:
+ list: opr output and internal result for quantization.
+ """
+ if len(values) != 11:
+        raise NotImplementedError('Simplified MultiHeadAttention is not implemented')
+
+ q_in,k_in,v_in,q_w,q_b,k_w,k_b,v_w,v_b,o_w,o_b = values
+ embed_dim = op.attributes.get('embed_dim')
+ num_heads = op.attributes.get('num_heads')
+
+ if embed_dim is None or num_heads is None:
+ raise ValueError('Cannot fetch embed_dim or num_heads')
+
+ # setup parameters
+ batch_size = q_in.shape[0]
+ head_dim = embed_dim // num_heads
+ scale = head_dim ** -0.5
+
+ xq = F.linear(q_in, q_w, q_b)
+ xk = F.linear(k_in, k_w, k_b)
+ xv = F.linear(v_in, v_w, v_b)
+
+ B, N, _ = xq.shape
+
+ q = xq.reshape(B, N, num_heads, head_dim).permute(0, 2, 1, 3)
+ k = xk.reshape(B, N, num_heads, head_dim).permute(0, 2, 1, 3)
+ v = xv.reshape(B, N, num_heads, head_dim).permute(0, 2, 1, 3)
+
+ energy = (q @ k.transpose(-2, -1)) * scale
+ attn = energy.softmax(dim=-1)
+
+ feat = (attn @ v).transpose(1, 2).reshape(batch_size, -1, embed_dim)
+ out = F.linear(feat, o_w, o_b)
+
+ return out, q, k, v, energy, feat
+
+
def Add_forward(op: Operation, values: List[torch.Tensor], ctx: TorchBackendContext = None, **kwargs) -> torch.Tensor:
"""Performs element-wise binary addition (with Numpy-style broadcasting
support).
@@ -786,6 +835,9 @@ def GatherND_forward(op: Operation, values: List[torch.Tensor], ctx: TorchBacken
reshaped_output = output.reshape(*shape_i, *shape_j, *shape_k)
return output
+def Gelu_forward(op: Operation, values: List[torch.Tensor], ctx: TorchBackendContext = None, **kwargs) -> torch.Tensor:
+ [input_value] = values
+ return F.gelu(input_value)
def Greater_forward(op: Operation, values: List[torch.Tensor], ctx: TorchBackendContext = None, **kwargs) -> torch.Tensor:
input_a, input_b = values
@@ -1436,7 +1488,7 @@ def Split_forward(op: Operation, values: List[torch.Tensor], ctx: TorchBackendCo
split = op.attributes.get('split', 0)
[input_value] = values
if 'split' not in op.attributes:
- split = input_value.shape[axis] // len(op.outputs)
+ split = input_value.shape[axis] // len(op.outputs)
outputs = torch.split(input_value, split, axis)
return outputs
@@ -1525,6 +1577,18 @@ def LeakyRelu_forward(op: Operation, values: List[torch.Tensor], ctx: TorchBacke
return output
+def LayerNorm_forward(op: Operation, values: List[torch.Tensor], ctx: TorchBackendContext = None, **kwargs):
+ if len(values) != 3:
+ raise ValueError('Unsupported LayerNorm without affine')
+
+ input_data, weight, bias = values
+ eps = op.attributes.get('epsilon', 1e-5)
+ normalized_shape = weight.shape
+
+ output = F.layer_norm(input_data, normalized_shape, weight, bias, eps)
+ return output
+
+
def Pad_forward(op: Operation, values: List[torch.Tensor], ctx: TorchBackendContext = None, **kwargs):
mode = op.attributes.get('mode', 'constant')
input_data = values[0]
@@ -2118,20 +2182,20 @@ def Identity_forward(op: Operation, values: List[torch.Tensor], ctx: TorchBacken
return values[0]
def Onehot_forward(op: Operation, values: List[torch.Tensor], ctx: TorchBackendContext = None, **kwargs) -> torch.Tensor:
- """
- Produces a one-hot tensor based on inputs. The locations represented by the index values in the 'indices'
- input tensor will have 'on_value' and the other locations will have 'off_value' in the output tensor,
-
- where 'on_value' and 'off_value' are specified as part of required input argument 'values',
- which is a two-element tensor of format [off_value, on_value].
-
- The rank of the output tensor will be one greater than the rank of the input tensor.
- The additional dimension is for one-hot representation. The additional dimension will be inserted at the position specified by 'axis'.
- If 'axis' is not specified then then additional dimension will be inserted as the innermost dimension,
- i.e. axis=-1. The size of the additional dimension is specified by required scalar input 'depth'.
-
- The type of the output tensor is the same as the type of the 'values' input. Any entries in the 'indices'
- input tensor with values outside the range [-depth, depth-1] will result in one-hot representation
+ """Produces a one-hot tensor based on inputs. The locations represented by
+ the index values in the 'indices' input tensor will have 'on_value' and the
+ other locations will have 'off_value' in the output tensor,
+
+ where 'on_value' and 'off_value' are specified as part of required input argument 'values',
+ which is a two-element tensor of format [off_value, on_value].
+
+ The rank of the output tensor will be one greater than the rank of the input tensor.
+ The additional dimension is for one-hot representation. The additional dimension will be inserted at the position specified by 'axis'.
+    If 'axis' is not specified then the additional dimension will be inserted as the innermost dimension,
+ i.e. axis=-1. The size of the additional dimension is specified by required scalar input 'depth'.
+
+ The type of the output tensor is the same as the type of the 'values' input. Any entries in the 'indices'
+ input tensor with values outside the range [-depth, depth-1] will result in one-hot representation
with all 'off_value' values in the output tensor.
when axis = 0:
@@ -2144,30 +2208,30 @@ def Onehot_forward(op: Operation, values: List[torch.Tensor], ctx: TorchBackendC
Attributes
axis : int (default is -1)
- (Optional) Axis along which one-hot representation in added. Default: axis=-1. axis=-1 means that
- the additional dimension will be inserted as the innermost/last dimension in the output tensor.
+        (Optional) Axis along which one-hot representation is added. Default: axis=-1. axis=-1 means that
+ the additional dimension will be inserted as the innermost/last dimension in the output tensor.
Negative value means counting dimensions from the back. Accepted range is [-r-1, r] where r = rank(indices).
-
+
Inputs
indices (non-differentiable) : T1
Input tensor containing indices. Any entries in the 'indices' input tensor with values outside the range [-depth, depth-1]
- will result in one-hot representation with all 'off_value' values in the output tensor.In case 'indices' is of non-integer type,
+        will result in one-hot representation with all 'off_value' values in the output tensor. In case 'indices' is of non-integer type,
the values will be casted to int64 before use.
-
+
depth (non-differentiable) : T2
- Scalar specifying the number of classes in one-hot tensor.
+ Scalar specifying the number of classes in one-hot tensor.
This is also the size of the one-hot dimension (specified by 'axis' attribute) added on in the output tensor.
- The values in the 'indices' input tensor are expected to be in the range [-depth, depth-1].
+ The values in the 'indices' input tensor are expected to be in the range [-depth, depth-1].
In case 'depth' is of non-integer type, it will be casted to int64 before use.
values (non-differentiable) : T3
- Rank 1 tensor containing exactly two elements,
- in the format [off_value, on_value], where 'on_value' is the value used for filling locations specified in 'indices' input tensor,
+ Rank 1 tensor containing exactly two elements,
+ in the format [off_value, on_value], where 'on_value' is the value used for filling locations specified in 'indices' input tensor,
and 'off_value' is the value used for filling locations other than those specified in 'indices' input tensor.
Outputs
output (non-differentiable) : T3
- Tensor of rank one greater than input tensor 'indices', i.e. rank(output) = rank(indices) + 1.
+ Tensor of rank one greater than input tensor 'indices', i.e. rank(output) = rank(indices) + 1.
The data type for the elements of the output tensor is the same as the type of input 'values' is used.
"""
# implementation from https://github.com/ToriML/onnx2pytorch/blob/master/onnx2pytorch/operations/onehot.py
@@ -2187,10 +2251,10 @@ def Onehot_forward(op: Operation, values: List[torch.Tensor], ctx: TorchBackendC
order.insert(axis, -1)
out = out.permute(order)
return out
-
+
def Reciprocal_forward(op: Operation, values: List[torch.Tensor], ctx: TorchBackendContext = None, **kwargs) -> torch.Tensor:
"""
- Reciprocal takes one input data (Tensor) and produces one output data (Tensor) where the reciprocal is,
+ Reciprocal takes one input data (Tensor) and produces one output data (Tensor) where the reciprocal is,
y = 1/x, is applied to the tensor elementwise.
Version
@@ -2231,11 +2295,13 @@ def Reciprocal_forward(op: Operation, values: List[torch.Tensor], ctx: TorchBack
'Gather': Gather_forward,
'GatherElements': Gather_forward,
'GatherND': GatherND_forward,
+ 'Gelu': Gelu_forward,
'Gemm': Gemm_forward,
'grid_sampler': Grid_sampler_forward,
'GlobalAveragePool': AveragePool_forward,
'GlobalMaxPool': MaxPool2d_forward,
'Greater': Greater_forward,
+ 'LayerNorm': LayerNorm_forward,
'LeakyRelu': LeakyRelu_forward,
'Less': Less_forward,
'MatMul': MatMul_forward,
@@ -2243,6 +2309,7 @@ def Reciprocal_forward(op: Operation, values: List[torch.Tensor], ctx: TorchBack
'MaxPool': MaxPool2d_forward,
'Min': Eltwise_forward,
'Mul': Mul_forward,
+ 'MultiHeadAttention': MultiHeadAttention_forward,
'NonMaxSuppression': _NMS_forward,
'NonZero': NonZero_forward,
'Not': Not_forward,
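
The MultiHeadAttention forward above is plain torch; a self-contained sketch with toy sizes (no PPQ objects involved) tracing the intermediate shapes that the five fake outputs expose for quantization.

    import torch
    import torch.nn.functional as F

    B, N, embed_dim, num_heads = 2, 4, 8, 2
    head_dim = embed_dim // num_heads
    scale = head_dim ** -0.5

    x = torch.randn(B, N, embed_dim)                    # q_in == k_in == v_in in a ViT block
    w = {name: torch.randn(embed_dim, embed_dim) for name in 'qkvo'}
    b = {name: torch.zeros(embed_dim) for name in 'qkvo'}

    q = F.linear(x, w['q'], b['q']).reshape(B, N, num_heads, head_dim).permute(0, 2, 1, 3)
    k = F.linear(x, w['k'], b['k']).reshape(B, N, num_heads, head_dim).permute(0, 2, 1, 3)
    v = F.linear(x, w['v'], b['v']).reshape(B, N, num_heads, head_dim).permute(0, 2, 1, 3)

    energy = (q @ k.transpose(-2, -1)) * scale          # (B, num_heads, N, N)
    feat = (energy.softmax(dim=-1) @ v).transpose(1, 2).reshape(B, -1, embed_dim)
    out = F.linear(feat, w['o'], b['o'])                # (B, N, embed_dim)

    assert q.shape == (B, num_heads, N, head_dim)
    assert energy.shape == (B, num_heads, N, N)
    assert feat.shape == out.shape == (B, N, embed_dim)
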
diff --git a/ppq/executor/torch.py b/ppq/executor/torch.py
index 16cd31f3..eae53326 100644
--- a/ppq/executor/torch.py
+++ b/ppq/executor/torch.py
@@ -390,8 +390,8 @@ def __forward(
if output_var.name in output_names:
result_collector[output_names.index(output_var.name)] = outputs[output_idx]
- except Exception as _:
- raise RuntimeError(f'Error happens when dealing with operation {str(operation)}')
+ except Exception as e:
+ raise RuntimeError(f'Error happens when dealing with operation {str(operation)}, {str(e)}')
# remove useless value(runtime clear).
visited_op.append(operation)
diff --git a/ppq/parser/ncnn_exporter.py b/ppq/parser/ncnn_exporter.py
index 7285759a..ba4aac64 100644
--- a/ppq/parser/ncnn_exporter.py
+++ b/ppq/parser/ncnn_exporter.py
@@ -8,12 +9,48 @@
from .onnx_exporter import OnnxExporter
from .util import convert_value
+import toml
+
+# custom toml encoder: dump arrays on one line using the configured separator
+class ArrayEncoder(toml.TomlEncoder):
+
+ def __init__(self, _dict=dict, preserve=False, separator=","):
+ super(ArrayEncoder, self).__init__(_dict, preserve)
+ if separator.strip() == "":
+ separator = "," + separator
+ elif separator.strip(' \t\n\r,'):
+ raise ValueError("Invalid separator for arrays")
+ self.separator = separator
+
+ def dump_list(self, v):
+ t = []
+ retval = "["
+ for u in v:
+ t.append(self.dump_value(u))
+ while t != []:
+ s = []
+ last = len(t) - 1
+ for idx, u in enumerate(t):
+ if isinstance(u, list):
+ for r in u:
+ s.append(r)
+ elif idx != last:
+ retval += " " + str(u) + self.separator
+ else:
+ retval += " " + str(u) + " "
+ t = s
+ retval += "]"
+ return retval
class NCNNExporter(GraphExporter):
- def export_quantization_config(self, config_path: str, graph: BaseGraph):
+    ''' ncnn quant table exporter; the raw (non-.ini) format only supports Conv and Gemm quantization '''
+ def export_raw_quant_config(self, config_path: str, graph: BaseGraph):
+        ''' raw ncnn table format, for ncnn versions <= 20220629 '''
fd = open(config_path, 'w+')
topo_order = graph.topological_sort()
for op in topo_order:
+ if op.type not in {'Conv', 'Gemm'}:
+ continue
if op.is_computing_op and isinstance(op, QuantableOperation):
fd.write(f'{op.name}_param_0 ')
param_cfg = op.config.input_quantization_config[1]
@@ -32,6 +69,8 @@ def export_quantization_config(self, config_path: str, graph: BaseGraph):
fd.write('%f '% s)
fd.write('\n')
for op in topo_order:
+ if op.type not in {'Conv', 'Gemm'}:
+ continue
if op.is_computing_op and isinstance(op, QuantableOperation):
fd.write(f'{op.name} ')
input_cfg = op.config.input_quantization_config[0]
@@ -42,6 +81,114 @@ def export_quantization_config(self, config_path: str, graph: BaseGraph):
fd.write('\n')
fd.close()
+
+ def export_ini_quant_config(self, config_path: str, graph: BaseGraph):
+        ''' toml is a human-readable format '''
+ order = graph.topological_sort()
+ table = {}
+ for op in order:
+ if hasattr(op, 'config'):
+ item = dict()
+                # ncnn may rename ops during conversion (e.g. Gather becomes Crop), so opr_type cannot be deduced from opr_name
+ item['type'] = op.type
+ if op.type in {'Conv', 'Gemm'}:
+ input_cfg = op.config.input_quantization_config[0]
+ assert input_cfg.state == QuantizationStates.ACTIVATED and \
+ input_cfg.policy.has_property(QuantizationProperty.PER_TENSOR)
+ item['input_scale'] = convert_value(1 / input_cfg.scale, True, DataType.FP32)
+
+ param_cfg = op.config.input_quantization_config[1]
+ assert param_cfg.state in {QuantizationStates.BAKED, QuantizationStates.ACTIVATED}\
+ and param_cfg.observer_algorithm in {'minmax', 'Minmax'} and \
+ param_cfg.policy.has_property(QuantizationProperty.PER_CHANNEL)
+ # a workaround for depthwise conv in ncnn
+ # will cause mis-alignment between ppq and ncnn
+ if op.type == 'Conv' and op.attributes.get('group', 1) > 1:
+ group = op.attributes.get('group', 1)
+ scale = param_cfg.scale.reshape(group, -1).max(dim=1)[0]
+ else:
+ scale = param_cfg.scale
+ item['weight'] = convert_value(1 / scale, False, DataType.FP32)
+
+ # elif op.type in {'Add'}:
+ # # Add may have multiple input node
+ # input_scales = []
+ # for cfg_in in op.config.input_quantization_config:
+ # assert cfg_in.state in {QuantizationStates.BAKED, QuantizationStates.ACTIVATED} \
+ # and cfg_in.observer_algorithm in {'minmax', 'Minmax'}
+ # input_scales.append(convert_value(1.0 / cfg_in.scale, True, DataType.FP32))
+ # item['input_scales'] = input_scales
+
+ # cfg_out = op.config.output_quantization_config[0]
+ # item['output_scale'] = convert_value(1.0 / cfg_out.scale, True, DataType.FP32)
+
+ elif op.type in {'Gelu'}:
+ cfg = op.config.input_quantization_config[0]
+
+ assert cfg.state in {QuantizationStates.BAKED, QuantizationStates.ACTIVATED} \
+ and cfg.observer_algorithm in {'minmax', 'Minmax'}
+ item['input_scale'] = convert_value(1.0 / cfg.scale, True, DataType.FP32)
+
+ elif op.type in {'LayerNorm'}:
+ cfg_in = op.config.input_quantization_config[0]
+ cfg_out = op.config.output_quantization_config[0]
+
+ assert cfg_in.state in {QuantizationStates.BAKED, QuantizationStates.ACTIVATED} \
+ and cfg_in.observer_algorithm in {'minmax', 'Minmax'} \
+ and cfg_out.state in {QuantizationStates.BAKED, QuantizationStates.ACTIVATED}
+ item['input_scales'] = convert_value(1.0 / cfg_in.scale, False, DataType.FP32)
+ item['output_scale'] = convert_value(1.0 / cfg_out.scale, True, DataType.FP32)
+
+ elif op.type == 'MultiHeadAttention':
+ # write input scale
+ cfg_q_in = op.config.input_quantization_config[0]
+ cfg_k_in = op.config.input_quantization_config[1]
+ cfg_v_in = op.config.input_quantization_config[2]
+
+ item['input_scale_q'] = convert_value(1.0 / cfg_q_in.scale, True, DataType.FP32)
+ item['input_scale_k'] = convert_value(1.0 / cfg_k_in.scale, True, DataType.FP32)
+ item['input_scale_v'] = convert_value(1.0 / cfg_v_in.scale, True, DataType.FP32)
+
+ # write input/output weight scale, per-channel
+ cfg_q_w = op.config.input_quantization_config[3]
+ cfg_k_w = op.config.input_quantization_config[5]
+ cfg_v_w = op.config.input_quantization_config[7]
+ cfg_o_w = op.config.input_quantization_config[9]
+
+ item['weight_q'] = convert_value(1 / cfg_q_w.scale, False, DataType.FP32)
+ item['weight_k'] = convert_value(1 / cfg_k_w.scale, False, DataType.FP32)
+ item['weight_v'] = convert_value(1 / cfg_v_w.scale, False, DataType.FP32)
+ item['weight_o'] = convert_value(1 / cfg_o_w.scale, False, DataType.FP32)
+
+ # write internal scale
+ cfg_q = op.config.output_quantization_config[1]
+ cfg_k = op.config.output_quantization_config[2]
+ cfg_v = op.config.output_quantization_config[3]
+ cfg_energy = op.config.output_quantization_config[4]
+ cfg_feat = op.config.output_quantization_config[5]
+
+ item['internal_scale_q'] = convert_value(1.0 / cfg_q.scale, True, DataType.FP32)
+ item['internal_scale_k'] = convert_value(1.0 / cfg_k.scale, True, DataType.FP32)
+ item['internal_scale_v'] = convert_value(1.0 / cfg_v.scale, True, DataType.FP32)
+ item['internal_scale_energy'] = convert_value(1.0 / cfg_energy.scale, True, DataType.FP32)
+ item['internal_scale_feat'] = convert_value(1.0 / cfg_feat.scale, True, DataType.FP32)
+
+ else:
+                print('unknown quantized op type {} (name {}) when writing weight scales'.format(op.type, op.name))
+ continue
+
+ table[op.name] = item
+
+ toml.dump(table, open(config_path, 'w+'), encoder=ArrayEncoder())
+
+
+ def export_quantization_config(self, config_path: str, graph: BaseGraph):
+ if config_path.endswith(".ini"):
+ print("export .ini format quant table, please make sure ncnn version >= 20220627")
+ self.export_ini_quant_config(config_path=config_path, graph=graph)
+ else:
+ self.export_raw_quant_config(config_path=config_path, graph=graph)
+
def export(self, file_path: str, graph: BaseGraph, config_path: str = None, input_shapes: List[List[int]] = [[1, 3, 224, 224]]):
if config_path is not None:
self.export_quantization_config(config_path, graph)
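
A hedged sketch of driving the new export path directly; quantized_graph stands for a BaseGraph that already went through calibration and quantization, and the file names are placeholders.

    from ppq.parser.ncnn_exporter import NCNNExporter

    exporter = NCNNExporter()
    # a ".ini" suffix selects the toml-based table (ncnn >= 20220627) ...
    exporter.export_quantization_config('model_quant.ini', quantized_graph)
    # ... any other suffix falls back to the raw Conv/Gemm-only table
    exporter.export_quantization_config('model_quant.table', quantized_graph)
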
diff --git a/ppq/parser/onnx_exporter.py b/ppq/parser/onnx_exporter.py
index ac2231fe..149a70c6 100644
--- a/ppq/parser/onnx_exporter.py
+++ b/ppq/parser/onnx_exporter.py
@@ -60,7 +61,7 @@ def export(self, operation: Operation, graph: BaseGraph, **kwargs) -> Operation:
def convert_value(value: Union[int, float, np.ndarray, torch.Tensor]) -> str:
if type(value) in {int, float}: return value
else:
- value = convert_any_to_numpy(value, accepet_none=True)
+ value = convert_any_to_numpy(value, accept_none=True)
if value is None: return value # SOI config has Nona as its scale and
return value.tolist()
@@ -113,7 +114,7 @@ def export_operation(self, operation: Operation) -> onnx.OperatorProto:
assert isinstance(exporter, OperationExporter), (
f'Expected an OpExporter here, however {type(exporter)} was given.')
operation = exporter.export(operation=operation, graph=None)
-
+
attributes = operation.attributes
for key in attributes:
value = attributes[key]
@@ -165,7 +166,15 @@ def export_var(self, variable: Variable) -> onnx.TensorProto:
dims=shape, vals=value)
return tensor_proto
+ def remove_fake_node_output(self, graph: BaseGraph):
+ for opr in graph.topological_sort():
+ if opr.type == 'MultiHeadAttention':
+ for var in opr.outputs[1:]:
+ graph.remove_variable(var)
+
def export(self, file_path: str, graph: BaseGraph, config_path: str = None):
+ self.remove_fake_node_output(graph)
+
# during export we will remove all boundary operations from graph.
# we do not want to change the structure of original graph,
# so there have to take a clone of it.
diff --git a/ppq/parser/util.py b/ppq/parser/util.py
index c2bffb9f..17c7c928 100644
--- a/ppq/parser/util.py
+++ b/ppq/parser/util.py
@@ -21,7 +21,7 @@ def convert_value(
if dtype not in {DataType.FP32, DataType.INT32}:
raise ValueError(f'Can Only export dtype fp32 and int32, '
f'while you are requiring to dump a {dtype.name} value')
- value = convert_any_to_numpy(value, accepet_none=False)
+ value = convert_any_to_numpy(value, accept_none=False)
value = value.astype(dtype=DataType.to_numpy(dtype))
if export_as_float:
value = np.asscalar(value[0])
@@ -30,5 +30,5 @@ def convert_value(
f'It is Expected to be a int or float value, while {type(value)} was given')
return value
else:
- value = convert_any_to_numpy(value, accepet_none=False)
+ value = convert_any_to_numpy(value, accept_none=False)
return value.tolist()
diff --git a/ppq/quantization/algorithm/training.py b/ppq/quantization/algorithm/training.py
index 7b5f2d0d..9712111c 100644
--- a/ppq/quantization/algorithm/training.py
+++ b/ppq/quantization/algorithm/training.py
@@ -221,9 +221,11 @@ def PPQTensorClip(
limit: torch.Tensor, config: TensorQuantizationConfig) -> torch.Tensor:
if config.policy.has_property(QuantizationProperty.PER_CHANNEL):
assert isinstance(config, ChannelwiseTensorQuantizationConfig)
- return Clip_C.apply(tensor, reference, limit, config.channel_axis)
+ return Clip_C.apply(tensor, reference, limit, config.channel_axis)
elif config.policy.has_property(QuantizationProperty.PER_TENSOR):
return Clip_T.apply(tensor, reference, limit)
+ elif config.policy.has_property(QuantizationProperty.PER_CHANNEL_BNC):
+        raise Exception('PER_CHANNEL_BNC is not implemented in PPQTensorClip')
else: raise Exception('Oops, seems we got some problems here.')
@@ -240,6 +242,8 @@ def PPQRoundingLoss(tensor: torch.Tensor,
return RoundingLoss_T.apply(
tensor, config.scale, config.offset, config.quant_min,
config.quant_max, config.rounding)
+ elif config.policy.has_property(QuantizationProperty.PER_CHANNEL_BNC):
+        raise Exception('PER_CHANNEL_BNC is not implemented in PPQRoundingLoss')
else: raise Exception('Oops, seems we got some problems here.')
@@ -743,7 +747,10 @@ def __call__(self, tensor: torch.Tensor, config: TensorQuantizationConfig) -> to
shape = [1 if axis != config.channel_axis else -1 for axis in range(tensor.ndim)]
scale = scale.view(shape)
bias = bias.view(shape)
-
+
+ if config.policy.has_property(QuantizationProperty.PER_CHANNEL_BNC):
+            raise Exception('PER_CHANNEL_BNC is not implemented in StraightThroughEstimateDelegator')
+
# only bias doesn't need offset in asym quant
if not self.passive and config.policy.has_property(QuantizationProperty.ASYMMETRICAL):
tensor = tensor + bias.abs()
@@ -782,6 +789,9 @@ def initiate_rounding(self) -> Union[None, torch.nn.Parameter]:
assert isinstance(self.config, ChannelwiseTensorQuantizationConfig)
shape = [1 if axis != self.config.channel_axis else -1 for axis in range(weight.ndim)]
scale = scale.view(shape)
+ elif self.config.policy.has_property(QuantizationProperty.PER_CHANNEL_BNC):
+            raise Exception('PER_CHANNEL_BNC is not implemented in StraightThroughEstimateDelegator')
+
round_diff = (weight / scale) - (weight / scale).floor()
v_init = -torch.log((self.reg.zeta - self.reg.gamma) / (round_diff - self.reg.gamma) - 1)
continuous_v = torch.nn.Parameter(v_init, True)
@@ -813,6 +823,9 @@ def finalize(self) -> None:
shape = [1 if axis != self.config.channel_axis else -1 for axis in range(weight.ndim)]
scale = scale.view(shape)
offset = offset.view(shape)
+        elif self.config.policy.has_property(QuantizationProperty.PER_CHANNEL_BNC):
+            raise Exception('PER_CHANNEL_BNC is not implemented in BlockwiseReconstructionDelegator finalize')
+
weight = (weight / scale).floor() + (self.rounding >= 0).float()
weight = torch.clamp(weight + offset, self.config.quant_min, self.config.quant_max)
weight = (weight - offset) * scale
@@ -829,6 +842,9 @@ def __call__(self, tensor: torch.Tensor, config: TensorQuantizationConfig) -> to
shape = [1 if axis != config.channel_axis else -1 for axis in range(tensor.ndim)]
scale = scale.view(shape)
offset = offset.view(shape)
+ elif config.policy.has_property(QuantizationProperty.PER_CHANNEL_BNC):
+            raise Exception('PER_CHANNEL_BNC is not implemented in BlockwiseReconstructionDelegator __call__')
+
tensor = (tensor / scale).floor() + self.reg.rectified_sigmoid(self.rounding)
tensor = torch.clamp(tensor + offset, config.quant_min, config.quant_max)
tensor = (tensor - offset) * scale
diff --git a/ppq/quantization/observer/range.py b/ppq/quantization/observer/range.py
index 53d085c6..73deff82 100644
--- a/ppq/quantization/observer/range.py
+++ b/ppq/quantization/observer/range.py
@@ -14,7 +14,49 @@
from ppq.utils.round import ppq_numerical_round, ppq_round_to_power_of_2
from .base import BaseTensorObserver
+from .utils import lp_loss
+# https://github.com/megvii-research/FQ-ViT/blob/main/models/ptq/observer/ptf.py#L31
+@ ppq_quant_param_computing_function
+def PTF_BNC_to_scale_offset(
+ min_val: list, max_val: list,
+ inputs: torch.Tensor,
+ config: TensorQuantizationConfig
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ max_val = torch.Tensor(max_val)
+ min_val = torch.Tensor(min_val)
+
+ qmax = config.quant_max
+ qmin = config.quant_min
+
+ max_val_t = max_val.max()
+ min_val_t = min_val.min()
+ scale8 = (max_val_t - min_val_t) / float(qmax - qmin)
+ scale8.clamp_(1e-7)
+ scale4 = scale8 / 2
+ scale2 = scale4 / 2
+ scale1 = scale2 / 2
+ zero_point = qmin - torch.round(min_val_t / scale8)
+ zero_point.clamp_(qmin, qmax)
+ scale_mask = torch.ones_like(max_val)
+ for j in range(inputs.shape[2]):
+ data = inputs[..., j].unsqueeze(-1)
+ data_q1 = ((data / scale1 + zero_point).round().clamp(qmin, qmax) -
+ zero_point) * scale1
+ data_q2 = ((data / scale2 + zero_point).round().clamp(qmin, qmax) -
+ zero_point) * scale2
+ data_q4 = ((data / scale4 + zero_point).round().clamp(qmin, qmax) -
+ zero_point) * scale4
+ data_q8 = ((data / scale8 + zero_point).round().clamp(qmin, qmax) -
+ zero_point) * scale8
+ score1 = lp_loss(data, data_q1, p=2.0, reduction='all')
+ score2 = lp_loss(data, data_q2, p=2.0, reduction='all')
+ score4 = lp_loss(data, data_q4, p=2.0, reduction='all')
+ score8 = lp_loss(data, data_q8, p=2.0, reduction='all')
+ score = [score1, score2, score4, score8]
+ scale_mask[j] *= 2**score.index(min(score))
+ scale = scale1 * scale_mask
+ return scale, zero_point
@ ppq_quant_param_computing_function
def minmax_to_scale_offset(
@@ -48,6 +90,7 @@ def __init__(self, watch_on: Variable, quant_cfg: TensorQuantizationConfig):
super().__init__(watch_on, quant_cfg)
self._min_val_collector = []
self._max_val_collector = []
+ self._last_input = None
@ torch.no_grad()
def observe(self, value: torch.Tensor):
@@ -66,6 +109,13 @@ def observe(self, value: torch.Tensor):
channelwise_view = torch.flatten(channelwise_view, start_dim=1)
self._min_val_collector.append(torch.min(channelwise_view, dim=1, keepdim=True)[0])
self._max_val_collector.append(torch.max(channelwise_view, dim=1, keepdim=True)[0])
+ elif self._quant_cfg.policy.has_property(QuantizationProperty.PER_CHANNEL_BNC):
+ assert len(value.shape) == 3
+ channelwise_view = value.reshape(-1, value.shape[-1])
+ channelwise_view = channelwise_view.transpose(0,1)
+ self._min_val_collector.append(torch.min(channelwise_view, dim=1, keepdim=True)[0])
+ self._max_val_collector.append(torch.max(channelwise_view, dim=1, keepdim=True)[0])
+ self._last_input = value
else:
raise TypeError('Min-max Observer only work with per-tensor or per-channel quantize policy.')
@@ -100,6 +150,19 @@ def render_quantization_config(self):
self._quant_cfg.scale = torch.tensor(scales, dtype=torch.float32, device=device)
self._quant_cfg.offset = torch.tensor(offsets, dtype=torch.float32, device=device)
self._quant_cfg.state = QuantizationStates.ACTIVATED
+ elif self._quant_cfg.policy.has_property(QuantizationProperty.PER_CHANNEL_BNC):
+
+ min_vals = torch.min(torch.cat(self._min_val_collector, dim=-1), dim=-1, keepdim=False)[0].cpu().numpy()
+ max_vals = torch.max(torch.cat(self._max_val_collector, dim=-1), dim=-1, keepdim=False)[0].cpu().numpy()
+            assert len(min_vals) == len(max_vals), 'Min values and max values should be of the same length.'
+
+ scales, offsets = PTF_BNC_to_scale_offset(min_val=min_vals, max_val=max_vals, inputs=self._last_input, config=self._quant_cfg)
+
+ # scale, offset here only deployed on cpu
+ # we will move them towards target device through RunnableGraph
+            self._quant_cfg.scale = scales.clone().detach().to(dtype=torch.float32, device=device)
+            self._quant_cfg.offset = offsets.clone().detach().to(dtype=torch.float32, device=device)
+ self._quant_cfg.state = QuantizationStates.ACTIVATED
else:
raise TypeError('Min-max Observer only work with per-tensor or per-channel quantize policy.')
@@ -260,10 +323,12 @@ def observe(self, value: torch.Tensor):
self._percentile_collector.append(torch.cat([_max, _min], dim=-1))
else:
self._percentile_collector.append(CUDA.Quantile(value, self._percentile).view(1, -1))
- elif self._quant_cfg.policy.has_property(QuantizationProperty.PER_CHANNEL):
- raise PermissionError('Percentile observer can not deal with per channel quantization.')
+ elif self._quant_cfg.policy.has_property(QuantizationProperty.PER_CHANNEL) or \
+ self._quant_cfg.policy.has_property(QuantizationProperty.PER_CHANNEL_BNC):
+ # raise PermissionError('Percentile observer can not deal with per channel quantization.')
+
assert isinstance(self._quant_cfg, ChannelwiseTensorQuantizationConfig), (
- 'Your quantization config has PER_CHANNEL while it is not a '
+                'Your quantization config has PER_CHANNEL / PER_CHANNEL_BNC while it is not a '
'ChannelwiseTensorQuantizationConfig instance.')
channel_axis = self._quant_cfg.channel_axis
channelwise_view = value.transpose(dim0=0, dim1=channel_axis)
@@ -288,8 +353,11 @@ def render_quantization_config(self):
self._quant_cfg.scale = torch.tensor([scale], dtype=torch.float32, device=device).squeeze(0)
self._quant_cfg.offset = torch.tensor([offset], dtype=torch.float32, device=device).squeeze(0)
self._quant_cfg.state = QuantizationStates.ACTIVATED
- elif self._quant_cfg.policy.has_property(QuantizationProperty.PER_CHANNEL):
- raise PermissionError('Percentile observer can not deal with per channel quantization.')
+ elif self._quant_cfg.policy.has_property(QuantizationProperty.PER_CHANNEL) or \
+ self._quant_cfg.policy.has_property(QuantizationProperty.PER_CHANNEL_BNC):
+ # raise PermissionError('Percentile observer can not deal with per channel quantization.')
if len(self._percentile_maxs) == 0:
raise ValueError('Can not render quantization config yet, Observer data collator is empty. ' \
'Invoke observe() function before render config.')
@@ -348,6 +416,9 @@ def hist_to_scale_offset(
if config.policy.has_property(QuantizationProperty.PER_CHANNEL):
raise PermissionError('Torch Mse observer do not support PER_CHANNEL policy now, please wait.')
+ if config.policy.has_property(QuantizationProperty.PER_CHANNEL_BNC):
+        raise PermissionError('Torch Mse observer does not support PER_CHANNEL_BNC policy now, please wait.')
+
if (config.policy.has_property(QuantizationProperty.SYMMETRICAL) and
config.policy.has_property(QuantizationProperty.PER_TENSOR)):
scale = 1 # hist scale
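
The per-channel loop in PTF_BNC_to_scale_offset can be hard to read inline; a condensed, self-contained restatement of the idea on a toy (B, N, C) tensor. This is an illustration of the selection rule only, not the observer itself.

    import torch

    def ptf_scales(x, qmin=-128, qmax=127):
        # one shared base scale from the global range ...
        scale8 = ((x.max() - x.min()) / float(qmax - qmin)).clamp(min=1e-7)
        zp = (qmin - torch.round(x.min() / scale8)).clamp(qmin, qmax)
        candidates = [scale8 / 8, scale8 / 4, scale8 / 2, scale8]
        # ... then each channel (last axis) keeps whichever power-of-two fraction
        # of that base scale reconstructs it with the smallest L2 error
        per_channel = []
        for c in range(x.shape[-1]):
            data = x[..., c]
            errs = []
            for s in candidates:
                q = ((data / s + zp).round().clamp(qmin, qmax) - zp) * s
                errs.append((q - data).pow(2).mean())
            per_channel.append(candidates[int(torch.stack(errs).argmin())])
        return torch.stack(per_channel), zp

    scales, zero_point = ptf_scales(torch.randn(2, 4, 8))
    assert scales.shape == (8,)      # one scale per channel, each equal to base / 2**k
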
diff --git a/ppq/quantization/observer/utils.py b/ppq/quantization/observer/utils.py
new file mode 100644
index 00000000..15d6032b
--- /dev/null
+++ b/ppq/quantization/observer/utils.py
@@ -0,0 +1,10 @@
+# Copyright (c) MEGVII Inc. and its affiliates. All Rights Reserved.
+# https://github.com/megvii-research/FQ-ViT
+def lp_loss(pred, tgt, p=2.0, reduction='none'):
+ """
+ loss function measured in L_p Norm
+ """
+ if reduction == 'none':
+ return (pred - tgt).abs().pow(p).sum(1).mean()
+ else:
+ return (pred - tgt).abs().pow(p).mean()
diff --git a/ppq/quantization/optim/calibration.py b/ppq/quantization/optim/calibration.py
index 87a9f443..94acd4b0 100644
--- a/ppq/quantization/optim/calibration.py
+++ b/ppq/quantization/optim/calibration.py
@@ -311,11 +311,12 @@ def optimize(self, processor: GraphCommandProcessor, dataloader: Iterable,
assert isinstance(cfg, TensorQuantizationConfig)
assert isinstance(observer, TorchMinMaxObserver)
- if observer._quant_cfg.policy.has_property(QuantizationProperty.PER_CHANNEL):
+ if observer._quant_cfg.policy.has_property(QuantizationProperty.PER_CHANNEL) or\
+ observer._quant_cfg.policy.has_property(QuantizationProperty.PER_CHANNEL_BNC):
min_vals = torch.min(torch.cat(observer._min_val_collector, dim=-1), dim=-1, keepdim=False)[0].cpu().numpy()
max_vals = torch.max(torch.cat(observer._max_val_collector, dim=-1), dim=-1, keepdim=False)[0].cpu().numpy()
cfg.detail.update({'range_min': min_vals, 'range_max': max_vals})
-
+
elif observer._quant_cfg.policy.has_property(QuantizationProperty.PER_TENSOR):
min_val = torch.min(torch.cat(observer._min_val_collector, dim=0)).cpu().item(),
max_val = torch.max(torch.cat(observer._max_val_collector, dim=0)).cpu().item(),
diff --git a/ppq/quantization/optim/training.py b/ppq/quantization/optim/training.py
index 81a46912..ad4d5eee 100644
--- a/ppq/quantization/optim/training.py
+++ b/ppq/quantization/optim/training.py
@@ -347,7 +347,7 @@ class BiasCorrectionPass(TrainingBasedPass):
def __init__(self, auto_check: bool=False, interested_output: List[str] = None,
verbose: bool = True, max_steps:int = 8) -> None:
"""Quantization can introduce a biased error in the activations. Bias
- correction serves as a useful prosedure to eliminate those introduced
+ correction serves as a useful procedure to eliminate those introduced
bias error.
let: Y = WX + b
@@ -379,7 +379,12 @@ def collect_bias(output: torch.Tensor, collector: list, op_type: str):
if op_type in {'Conv', 'ConvTranspose'}:
collector.append(torch.mean(output, dim=(0, 2, 3)).unsqueeze(0))
elif op_type in {'Gemm'}:
- collector.append(torch.mean(output, dim=(0, )).unsqueeze(0))
+ if len(output.shape) == 2:
+ collector.append(torch.mean(output, dim=(0, )).unsqueeze(0))
+ elif len(output.shape) == 3:
+ collector.append(torch.mean(output, dim=(0,1)).unsqueeze(0))
+ else:
+ raise ValueError(f'Unsupported Gemm shape: {output.shape}')
else: raise TypeError(f'Unsupported Operation type: {op_type}')
assert isinstance(executor, TorchExecutor), (
@@ -490,6 +495,8 @@ def optimize(
for axis in range(fp_weight.ndim)]
weight_scale = weight_scale.view(view_shape)
weight_offset = weight_offset.view(view_shape)
+ elif weight_quantization_config.policy.has_property(QuantizationProperty.PER_CHANNEL_BNC):
+                raise Exception('PER_CHANNEL_BNC is not implemented in AdaRoundPass optimize')
# init continuous_v, make sure h(v) = round_diff
round_diff = (fp_weight / weight_scale) - (fp_weight / weight_scale).floor()
diff --git a/ppq/quantization/qfunction/linear.py b/ppq/quantization/qfunction/linear.py
index 95310b20..aaede2b7 100644
--- a/ppq/quantization/qfunction/linear.py
+++ b/ppq/quantization/qfunction/linear.py
@@ -18,7 +18,6 @@ def __init__(self) -> None:
def __call__(self, input_tensor: Any, quantization_config: TensorQuantizationConfig, **kwargs) -> Any:
return super().__call__(input_tensor, quantization_config, **kwargs)
-
if not PPQ_CONFIG.USING_CUDA_KERNEL:
class TensorwiseLinearQuantImpl(Function):
"""Torch Tensorwise quantize is designed to quantize a torch Tensor
@@ -180,6 +179,13 @@ def PPQLinearQuantFunction(
tensor, config.scale, config.offset, config.channel_axis,
config.quant_min, config.quant_max, config.rounding,
dropout)
+ elif config.policy.has_property(QuantizationProperty.PER_CHANNEL_BNC):
+ if config.channel_axis != 2:
+            raise ValueError('op uses BNC format but channel_axis != 2')
+ return ChannelwiseLinearQuantImpl.apply(
+ tensor, config.scale, config.offset, config.channel_axis,
+ config.quant_min, config.quant_max, config.rounding,
+ dropout)
elif config.policy.has_property(QuantizationProperty.PER_TENSOR):
return TensorwiseLinearQuantImpl.apply(
tensor, config.scale, config.offset,
@@ -200,6 +206,12 @@ def PPQLinearQuant_toInt(tensor: torch.Tensor, config: TensorQuantizationConfig,
elif config.policy.has_property(QuantizationProperty.PER_TENSOR):
tensor = ppq_tensor_round((tensor / config.scale), config.rounding) + config.offset
tensor = torch.clamp(tensor, config.quant_min, config.quant_max)
+ elif config.policy.has_property(QuantizationProperty.PER_CHANNEL_BNC):
+        assert len(tensor.shape) == 3
+ scale = config.scale.reshape((1,1,-1))
+ offset = config.offset.reshape((1,1,-1))
+ tensor = ppq_tensor_round((tensor / scale), config.rounding) + offset
+ tensor = torch.clamp(tensor, config.quant_min, config.quant_max)
return tensor.type(dtype=torch.int32)
diff --git a/ppq/quantization/quantizer/NCNNQuantizer.py b/ppq/quantization/quantizer/NCNNQuantizer.py
index 91c5cf0a..9f9fba35 100644
--- a/ppq/quantization/quantizer/NCNNQuantizer.py
+++ b/ppq/quantization/quantizer/NCNNQuantizer.py
@@ -36,10 +36,10 @@ def init_quantize_config(
policy=self.quantize_policy, rounding=self.rounding_policy,
operation_meta=operation.meta_data, num_of_bits=self._num_of_bits,
quant_max=self._quant_max, quant_min=self._quant_min,
- observer_algorithm='percentile'
+ observer_algorithm='Minmax'
)
- if operation.type in {'Conv', 'Gemm'}:
+ if operation.type in {'Add', 'Conv', 'LayerNorm', 'MultiHeadAttention', 'Gemm', 'Gelu'}:
assert operation.num_of_input > 0, 'Seems you got a Computing layer with no parameters.'
if operation.type == 'Conv':
@@ -81,13 +81,90 @@ def init_quantize_config(
offsets = None, scales = None, channel_axis = 0
)
base_quant_config.input_quantization_config[1].observer_algorithm = 'Minmax'
-
- # if operation has bias
- if operation.num_of_input > 2:
- bias_config = base_quant_config.input_quantization_config[-1]
- bias_config.state = QuantizationStates.FP32
-
- base_quant_config.output_quantization_config[0].state = QuantizationStates.FP32
+
+ elif operation.type == 'LayerNorm':
+            # the LayerNorm input is quantized per-channel with power-of-2 (PTF) scales
+ inp_config = base_quant_config.input_quantization_config[0]
+ inp_config.policy = QuantizationPolicy(
+ QuantizationProperty.SYMMETRICAL +
+ QuantizationProperty.LINEAR +
+ QuantizationProperty.PER_CHANNEL_BNC +
+ QuantizationProperty.PTF_BNC
+ )
+ base_quant_config.input_quantization_config[0] = \
+ ChannelwiseTensorQuantizationConfig.convert_from_tensor_config(
+ convert_from = inp_config,
+ offsets = None, scales = None, channel_axis = 2
+ )
+ base_quant_config.input_quantization_config[0].observer_algorithm = 'Minmax'
+
+            # LayerNorm weight and bias are not quantized
+ wconfig = base_quant_config.input_quantization_config[1]
+ bconfig = base_quant_config.input_quantization_config[2]
+ wconfig.state = QuantizationStates.FP32
+ bconfig.state = QuantizationStates.FP32
+
+            # output quantization
+ output_policy = QuantizationPolicy(
+ QuantizationProperty.SYMMETRICAL +
+ QuantizationProperty.LINEAR +
+ QuantizationProperty.PER_TENSOR
+ )
+ base_quant_config.output_quantization_config[0].policy = output_policy
+ base_quant_config.output_quantization_config[0].observer_algorithm = 'Minmax'
+
+ elif operation.type == 'MultiHeadAttention':
+ # setup input quant param
+ input_policy = QuantizationPolicy(
+ QuantizationProperty.SYMMETRICAL +
+ QuantizationProperty.LINEAR +
+ QuantizationProperty.PER_TENSOR
+ )
+ input_indexes = [0, 1, 2]
+ for index in input_indexes:
+ base_quant_config.input_quantization_config[index].policy = input_policy
+ base_quant_config.input_quantization_config[index].observer_algorithm = 'Minmax'
+
+ # setup weight quant param
+ fc_weight_config = base_quant_config.input_quantization_config[3]
+ fc_weight_config.policy = QuantizationPolicy(
+ QuantizationProperty.SYMMETRICAL +
+ QuantizationProperty.LINEAR +
+ QuantizationProperty.PER_CHANNEL
+ )
+
+ # setup qkv weight quant param
+ fc_weight_indexes = [3, 5, 7, 9]
+ for index in fc_weight_indexes:
+ base_quant_config.input_quantization_config[index] = \
+ ChannelwiseTensorQuantizationConfig.convert_from_tensor_config(
+ convert_from = fc_weight_config,
+ offsets = None, scales = None, channel_axis = 0
+ )
+ base_quant_config.input_quantization_config[index].observer_algorithm = 'Minmax'
+
+ # bias not quant
+ bias_indexes = [4, 6, 8, 10]
+ for index in bias_indexes:
+ base_quant_config.input_quantization_config[index].state = QuantizationStates.FP32
+
+ # setup internal quant policy
+            # note: internal results are exposed as extra op outputs (the fake outputs appended by FORMAT_MHA)
+ internal_policy = QuantizationPolicy(
+ QuantizationProperty.SYMMETRICAL +
+ QuantizationProperty.LINEAR +
+ QuantizationProperty.PER_TENSOR
+ )
+
+ internal_output_indexes = [1, 2, 3, 4, 5]
+ for index in internal_output_indexes:
+ base_quant_config.output_quantization_config[index].policy = internal_policy
+ base_quant_config.output_quantization_config[index].observer_algorithm = 'Minmax'
+
+
+        # explicitly mark the output as not quantized for the remaining op types
+ if operation.type not in {'MultiHeadAttention', 'LayerNorm', 'Add'}:
+ base_quant_config.output_quantization_config[0].state = QuantizationStates.FP32
return base_quant_config
@ property
@@ -101,7 +178,8 @@ def default_platform(self) -> TargetPlatform:
@ property
def quant_operation_types(self) -> set:
return {
- 'Conv', 'Gemm'
+ 'Conv', 'LayerNorm', 'MultiHeadAttention', 'Gemm'
+ # 'Add', 'Conv', 'LayerNorm', 'MultiHeadAttention', 'Gemm'
}
@ property
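
The index constants used above follow the input order that MultiHeadAttention_forward unpacks; a small reference map of this patch's convention (documentation only, not ncnn's own layout).

    # mirrors `q_in, k_in, v_in, q_w, q_b, k_w, k_b, v_w, v_b, o_w, o_b = values`
    MHA_INPUT_LAYOUT = {
        0: 'q_in', 1: 'k_in', 2: 'v_in',            # activations: per-tensor, Minmax
        3: 'q_w', 5: 'k_w', 7: 'v_w', 9: 'o_w',     # weights: per-channel (axis 0), Minmax
        4: 'q_b', 6: 'k_b', 8: 'v_b', 10: 'o_b',    # biases: kept FP32
    }
    # output_quantization_config slots 1..5 hold the fake internal outputs
    # appended by FORMAT_MHA (slot 0 is the real output)
    MHA_INTERNAL_OUTPUTS = {1: 'q', 2: 'k', 3: 'v', 4: 'energy', 5: 'feat'}
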
diff --git a/ppq/quantization/quantizer/base.py b/ppq/quantization/quantizer/base.py
index 0556f49d..491f6b45 100644
--- a/ppq/quantization/quantizer/base.py
+++ b/ppq/quantization/quantizer/base.py
@@ -110,12 +110,12 @@ def quantize_operations(
operation_platforms[op_name] = self.target_platform
else: operation_platforms[op_name] = self.default_platform
- # maunnl override.
+ # manual override.
if op_name in operation_platforms:
operation.platform = operation_platforms[op_name]
# build operation_quantization_configs
- # every quantable op MUST have a quantization config
+ # every quantizable op MUST have a quantization config
# if operation.type is listed in quantable_operation_types while a operation_quantization_configs is given
# it will override the setting of quantable_operation_types
for op_name, operation in self._graph.operations.items():
diff --git a/requirements.txt b/requirements.txt
index 03756530..7e8af6a3 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -3,3 +3,4 @@ onnx >= 1.8.1
protobuf
torch >= 1.6.0
tqdm
+toml
\ No newline at end of file