Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion doc/pages/instructions/ppq_quant_1.html
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,8 @@ <h2 id="headings">QuantizationPolicy</h2>
<p>QuantizationPolicy 在 PPQ 中用来描述量化策略,它是一些 QuantizationProperty 枚举的组合位图。在 PPQ 中我们支持的 QuantizationProperty 包括:</p>
<ul>
<li><strong>PER_TENSOR:</strong>逐层量化。</li>
<li><strong>PER_CHANNEL:</strong>逐通道量化。</li>
<li><strong>PER_CHANNEL:</strong>CNN 模型逐通道量化。</li>
<li><strong>PER_CHANNEL_BNC:</strong>Transformer 模型逐通道量化。</li>
<li><strong>LINEAR: </strong>线性量化。</li>
<li><strong>EXPONENTIAL: </strong> 指数量化。</li>
<li><strong>SYMMETRICAL: </strong>对称量化。</li>
Expand Down
2 changes: 2 additions & 0 deletions ppq/IR/base/command.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,8 @@ class GraphCommandType(Enum):
FORMAT_SLICE = 29
# 从一个指定位置将图截断
TRUNCATE_ON_VAR = 30
# 输出 MultiHeadAttention 中间过程,从而让 Quantizer 配置是否量化
FORMAT_MHA = 31

class GraphCommand():
def __init__(self, command_type: GraphCommandType, **kwargs) -> None:
Expand Down
12 changes: 6 additions & 6 deletions ppq/IR/deploy.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def retrieve(self):
for _, variable in self._graph.variables.items():
assert isinstance(variable, Variable), \
f'Failed to send graph to device, incorrect variable {variable} found.'
variable.value = convert_any_to_numpy(variable.value, accepet_none=True)
variable.value = convert_any_to_numpy(variable.value, accept_none=True)

return self

Expand All @@ -84,12 +84,12 @@ def deploy(self, device: str):
if operator.type == 'Constant' and operator.platform != TargetPlatform.SHAPE_OR_INDEX:
operator.attributes['value'] = \
convert_any_to_torch_tensor(
operator.attributes['value'], accepet_none=False).to(device)
operator.attributes['value'], accept_none=False).to(device)

if operator.type == 'Constant' and operator.platform == TargetPlatform.SHAPE_OR_INDEX:
value = operator.attributes['value']
operator.attributes['value'] = convert_any_to_torch_tensor(
value, accepet_none=False, device='cpu')
value, accept_none=False, device='cpu')

for _, variable in self._graph.variables.items():
assert isinstance(variable, Variable), \
Expand All @@ -108,10 +108,10 @@ def deploy(self, device: str):
# if all downstream operations are shape related operations, send value to cpu
if platform == TargetPlatform.SHAPE_OR_INDEX:
variable.value = convert_any_to_torch_tensor(
variable.value, accepet_none=True).to('cpu')
variable.value, accept_none=True).to('cpu')
else:
variable.value = convert_any_to_torch_tensor(
variable.value, accepet_none=True).to(device=device)
variable.value, accept_none=True).to(device=device)

# if variable is a shape-related variable, send it to cpu.
if variable.is_parameter:
Expand All @@ -125,5 +125,5 @@ def deploy(self, device: str):
if dest_op.type in {'Reshape', 'Slice', 'Gather', 'Pad', 'Resize', 'Split', 'TopK', 'Tile', 'Expand'}:
if dest_idx >= 1 and len(variable.dest_ops) == 1:
variable.value = convert_any_to_torch_tensor(
variable.value, accepet_none=True).to('cpu')
variable.value, accept_none=True).to('cpu')
return self
37 changes: 33 additions & 4 deletions ppq/IR/morph.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,8 @@ def _acceptable_command_types(self) -> List[GraphCommandType]:
GraphCommandType.FORMAT_PARAMETERS,
GraphCommandType.FORMAT_CONSTANT_INPUT,
GraphCommandType.FORMAT_SLICE,
GraphCommandType.TRUNCATE_ON_VAR
GraphCommandType.TRUNCATE_ON_VAR,
GraphCommandType.FORMAT_MHA,
]

def process(self, command: GraphCommand) -> Any:
Expand All @@ -107,7 +108,7 @@ def process(self, command: GraphCommand) -> Any:
if command.command_type == GraphCommandType.FORMAT_INT64_CONSTANT:
return self.format_int64_constant()
if command.command_type == GraphCommandType.REPLACE_SUB:
return self.replace_substarction()
return self.replace_substraction()
if command.command_type == GraphCommandType.FORMAT_PARAMETERS:
return self.format_parameter_variables()
if command.command_type == GraphCommandType.FORMAT_CONSTANT_INPUT:
Expand All @@ -117,6 +118,8 @@ def process(self, command: GraphCommand) -> Any:
if command.command_type == GraphCommandType.TRUNCATE_ON_VAR:
assert isinstance(command, TruncateGraphCommand), f'Use TruncateGraphCommand here.'
return self.truncate_on_var(command.var, command.mark_as_output)
if command.command_type == GraphCommandType.FORMAT_MHA:
return self.format_mha()

def format_slice(self) -> None:
"""
Expand Down Expand Up @@ -398,8 +401,34 @@ def format_parameter_variables(self) -> None:

# pop variable from graph
self.graph.remove_variable(var)

def replace_substarction(self) -> None:

def format_mha(self) -> None:
mha = []
for operation in self.graph.operations.values():
if operation.type == 'MultiHeadAttention':
mha.append(operation)

for opr in mha:
assert isinstance(opr, Operation)
q_var = Variable(name=opr.name + '_fake_q_', source_op=opr)
k_var = Variable(name=opr.name + '_fake_k_', source_op=opr)
v_var = Variable(name=opr.name + '_fake_v_', source_op=opr)
energy_var = Variable(name=opr.name + '_fake_energy_', source_op=opr)
feat_var = Variable(name=opr.name + '_fake_feat_', source_op=opr)

opr.outputs.append(q_var)
opr.outputs.append(k_var)
opr.outputs.append(v_var)
opr.outputs.append(energy_var)
opr.outputs.append(feat_var)

self.graph.append_variable(q_var)
self.graph.append_variable(k_var)
self.graph.append_variable(v_var)
self.graph.append_variable(energy_var)
self.graph.append_variable(feat_var)

def replace_substraction(self) -> None:
substractions = []
for operation in self.graph.operations.values():
if operation.type == 'Sub':
Expand Down
3 changes: 3 additions & 0 deletions ppq/api/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -617,6 +617,7 @@ def format_graph(graph: BaseGraph) -> BaseGraph:
在 PPQ 中,我们不希望出现 Constant 算子,所有 Constant 输入将被当作 parameter variable 连接到下游算子上
在 PPQ 中,我们不希望出现 Batchnorm 算子,所有 Batchnorm 将被合并
在 PPQ 中,我们不希望出现权重共享的算子,所有被共享的权重将被复制分裂成多份
在 PPQ 中,我们希望 MultiHeadAttention 算子中间过程可被量化
在 PPQ 中,我们不希望出现孤立算子,所有孤立算子将被移除

This function takes pre-processing procedure with your graph.
Expand All @@ -639,6 +640,7 @@ def format_graph(graph: BaseGraph) -> BaseGraph:
formatter(GraphCommand(GraphCommandType.FORMAT_CAST))
formatter(GraphCommand(GraphCommandType.FORMAT_SLICE))
formatter(GraphCommand(GraphCommandType.FORMAT_CLIP))
formatter(GraphCommand(GraphCommandType.FORMAT_MHA))
formatter(GraphCommand(GraphCommandType.DELETE_ISOLATED))

return graph
Expand All @@ -657,6 +659,7 @@ def dispatch_graph(graph: BaseGraph, platform: TargetPlatform, setting: Quantiza
"""
assert platform in QUANTIZER_COLLECTION, (
f'Platform misunderstood, except one of following platform {QUANTIZER_COLLECTION.keys()}')

quantizer = QUANTIZER_COLLECTION[platform](graph) # 初始化一个 quantizer 没有很大代价...

if str(setting.dispatcher).lower() not in DISPATCHER_TABLE:
Expand Down
4 changes: 4 additions & 0 deletions ppq/api/setting.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,8 +443,12 @@ def academic_setting() -> QuantizationSetting:
@staticmethod
def ncnn_setting() -> QuantizationSetting:
default_setting = QuantizationSetting()
default_setting.bias_correct = True
default_setting.fusion = False
default_setting.dispatcher = 'pointwise'

default_setting.quantize_activation_setting.calib_algorithm = None

return default_setting

@ staticmethod
Expand Down
2 changes: 1 addition & 1 deletion ppq/core/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
LINEAR_ACTIVATIONS = {'Relu', 'Clip'}

# COPUTING OP 是所有计算层,该属性被用于联合定点和子图切分
COMPUTING_OP = {'Conv', 'Gemm', 'ConvTranspose', 'MatMul'}
COMPUTING_OP = {'Conv', 'Gemm', 'ConvTranspose', 'MatMul', 'LayerNorm', 'MultiHeadAttention'}
# SOI OP 是所有产生 SOI 输出的节点类型,该属性被用于子图切分
SOI_OP = {'TopK', 'Shape', 'NonMaxSuppression'}
# 强制联合定点的算子种类
Expand Down
30 changes: 15 additions & 15 deletions ppq/core/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,20 +212,20 @@ def num_of_output(self):

def convert_any_to_python_primary_type(
x: Union[torch.Tensor, np.ndarray, int, float, list, str],
accepet_none: bool=True) -> Union[int, float, list, str]:
if x is None and accepet_none: return None
if x is None and not accepet_none: raise ValueError('Trying to convert an empty value.')
accept_none: bool=True) -> Union[int, float, list, str]:
if x is None and accept_none: return None
if x is None and not accept_none: raise ValueError('Trying to convert an empty value.')
if isinstance(x, list) or isinstance(x, tuple): return list(x)
elif isinstance(x, int) or isinstance(x, float): return x
elif isinstance(x, torch.Tensor):
if x.numel() == 0 and accepet_none: return None
if x.numel() == 0 and not accepet_none: raise ValueError('Trying to convert an empty value.')
if x.numel() == 0 and accept_none: return None
if x.numel() == 0 and not accept_none: raise ValueError('Trying to convert an empty value.')
if str(x.device) != 'cpu': x = x.cpu()
if x.numel() == 1: return x.item()
if x.numel() > 1: return x.tolist()
elif isinstance(x, np.ndarray):
if x.size == 0 and accepet_none: return None
if x.size == 0 and not accepet_none: raise ValueError('Trying to convert an empty value.')
if x.size == 0 and accept_none: return None
if x.size == 0 and not accept_none: raise ValueError('Trying to convert an empty value.')
if x.size == 1: return x.reshape((1, )).tolist()[0]
if x.size > 1: return x.tolist()
elif isinstance(x, str):
Expand All @@ -236,14 +236,14 @@ def convert_any_to_python_primary_type(

def convert_any_to_numpy(
x: Union[torch.Tensor, np.ndarray, int, float, list, tuple],
accepet_none: bool=True) -> np.ndarray:
if x is None and accepet_none: return None
if x is None and not accepet_none: raise ValueError('Trying to convert an empty value.')
accept_none: bool=True) -> np.ndarray:
if x is None and accept_none: return None
if x is None and not accept_none: raise ValueError('Trying to convert an empty value.')
if isinstance(x, np.ndarray): return x
elif isinstance(x, int) or isinstance(x, float): return np.array([x, ])
elif isinstance(x, torch.Tensor):
if x.numel() == 0 and accepet_none: return None
if x.numel() == 0 and not accepet_none: raise ValueError('Trying to convert an empty value.')
if x.numel() == 0 and accept_none: return None
if x.numel() == 0 and not accept_none: raise ValueError('Trying to convert an empty value.')
if x.numel() == 1: return convert_any_to_numpy(x.detach().cpu().item())
if x.numel() > 1: return x.detach().cpu().numpy()
elif isinstance(x, list) or isinstance(x, tuple):
Expand All @@ -254,9 +254,9 @@ def convert_any_to_numpy(

def convert_any_to_torch_tensor(
x: Union[torch.Tensor, np.ndarray, int, float, list, tuple],
accepet_none: bool=True, dtype: torch.dtype=None, device='cpu') -> torch.Tensor:
if x is None and accepet_none: return None
if x is None and not accepet_none: raise ValueError('Trying to convert an empty value.')
accept_none: bool=True, dtype: torch.dtype=None, device='cpu') -> torch.Tensor:
if x is None and accept_none: return None
if x is None and not accept_none: raise ValueError('Trying to convert an empty value.')
if isinstance(x, list) or isinstance(x, tuple):
if all([type(element) == int for element in x]):
if dtype is None: dtype=torch.int64
Expand Down
25 changes: 16 additions & 9 deletions ppq/core/quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,9 @@ class QuantizationProperty(Enum):
PER_TENSOR: Also known as per-layer quantization, which mean all parameters of this layer share the same scale and offset.
(For Convulution layer and Gemm layer which has bias, bias layer will be negative quantized, they do not have a valid scale)

PER_CHANNEL: parameters are quantized alone channel axis, each channel has a stand-alone scale and offset.
PER_CHANNEL: CNN model parameters are quantized along channel axis, each channel has a stand-alone scale and offset.

PER_CHANNEL_BNC: transformer model parameters are quantized per-chanel, each channel has a stand-alone scale and offset.

LINEAR: Indicates a linear quantization, follow formula: quant(x) = clip(round(x / scale))

Expand All @@ -151,18 +153,22 @@ class QuantizationProperty(Enum):
ASYMMETRICAL: Indicates an asymmetrical quantization, offset is activated in this mode.

POWER_OF_2: Indicates a power-of-2 quantization, scale must be pow(2, k) in this mode.

PTF_BNC: Indicates a power-of-2 for quantization for layernorm input, scale must be pow(2, k) in this mode.

ATTENTION: Not all combinations of all 7 QuantizationProperty are valid, see QuantizationPolicy.__check_valid
ATTENTION: QuantizationPolicy is read-only, user can only assign its value when created, the only interface of
QuantizationPolicy is function QuantizationPolicy.has_property.
"""
PER_TENSOR = 0x00000001
PER_CHANNEL = 0x00000002
LINEAR = 0x00000004
EXPONENTIAL = 0x00000008
SYMMETRICAL = 0x00000010
ASYMMETRICAL = 0x00000020
POWER_OF_2 = 0x00000040
PER_TENSOR = 0x00000001
PER_CHANNEL = 0x00000002
PER_CHANNEL_BNC = 0x00000080
LINEAR = 0x00000004
EXPONENTIAL = 0x00000008
SYMMETRICAL = 0x00000010
ASYMMETRICAL = 0x00000020
POWER_OF_2 = 0x00000040
PTF_BNC = 0x00000100

def __or__(self, other: int) -> int:
return self.value + other
Expand Down Expand Up @@ -247,6 +253,7 @@ def __check_valid(cls, policy):
QuantizationProperty.SYMMETRICAL | QuantizationProperty.LINEAR | QuantizationProperty.PER_TENSOR | QuantizationProperty.POWER_OF_2,
QuantizationProperty.ASYMMETRICAL | QuantizationProperty.LINEAR | QuantizationProperty.PER_CHANNEL | QuantizationProperty.POWER_OF_2,
QuantizationProperty.SYMMETRICAL | QuantizationProperty.LINEAR | QuantizationProperty.PER_CHANNEL | QuantizationProperty.POWER_OF_2,
QuantizationProperty.SYMMETRICAL | QuantizationProperty.LINEAR | QuantizationProperty.PER_CHANNEL_BNC | QuantizationProperty.PTF_BNC,
}

def to_dict(self) -> dict:
Expand Down Expand Up @@ -658,7 +665,7 @@ class ChannelwiseTensorQuantizationConfig(TensorQuantizationConfig):
"""ChannelwiseTensorQuantizationConfig is a special case for tensor
quantization configuration.

Comparing with per-tensor quantization configuration, pre-channel quantization has a property
Comparing with per-tensor quantization configuration, per-channel quantization has a property
"channel_axis" to indicate a channel axis where quantization takes effects.

Along this axis, all tensor values will be quantized with a sharing scale and offset,
Expand Down
7 changes: 5 additions & 2 deletions ppq/csrc/cuda/PPQ.h
Original file line number Diff line number Diff line change
Expand Up @@ -112,13 +112,16 @@
ATTENTION: QuantizationPolicy is read-only, user can only assign its value when created, the only interface of
QuantizationPolicy is function QuantizationPolicy.has_property.
*/
# define PPQ_QPROPERTY_PER_TENSOR 0x00000001
# define PPQ_QPROPERTY_PER_CHANNEL 0x00000002
# define PPQ_QPROPERTY_PER_TENSOR 0x00000001
# define PPQ_QPROPERTY_PER_CHANNEL 0x00000002
# define PPQ_QPROPERTY_PER_CHANNEL_BNC 0x00000080

# define PPQ_QPROPERTY_LINEAR 0x00000004
# define PPQ_QPROPERTY_EXPONENTIAL 0x00000008
# define PPQ_QPROPERTY_SYMMETRICAL 0x00000010
# define PPQ_QPROPERTY_ASYMMETRICAL 0x00000020
# define PPQ_QPROPERTY_POWER_OF_2 0x00000040
# define PPQ_QPROPERTY_PTF_BNC 0x00000100

# define PPQ_CPP_EXTENSION
# endif
Expand Down
Loading