FP8 ONNX export issue #266

Open · wants to merge 32 commits into base: main

Commits (32)
c9e8504
test
wangxc2006 Oct 9, 2022
0491be5
support sophgo_tpu backend,initial ver
wangxc2006 Oct 9, 2022
ef3f30e
update, fix some bug
wangxc2006 Nov 3, 2022
96683b6
add deconv and refine SophgoTpuQuantizer
wangxc2006 Nov 3, 2022
e612cfd
add some verify code
wangxc2006 Nov 15, 2022
99775c4
for qat int8 release
wangxc2006 Dec 1, 2022
252ede6
for qat int8 release
wangxc2006 Dec 5, 2022
cfddb1d
fix linear+bn bug
wangxc2006 Feb 27, 2023
21f39c0
add int4&int8 mix prec func and infer net output shape in xxx_mqmoble…
wangxc2006 Apr 12, 2023
232ca1f
fix int8 bug in int4 version
wangxc2006 Apr 12, 2023
1d09bc4
fix sub/abs op no fake quant node
wangxc2006 Jun 21, 2023
1f43b9d
add some class and func to adapt to torch1.10_cpu and torch2.0.1_cpu
zhengjin-xu11 Aug 2, 2023
47c3cde
commit message here
shee-gao Aug 14, 2023
2612562
commit message here
shee-gao Aug 15, 2023
4642917
qat gpt2
shee-gao Aug 29, 2023
99f2a16
Added FP8 fakequant and fixed some issues in config and prepare_by_platform
tjthereal Aug 30, 2023
e57a418
QAT example
shee-gao Aug 30, 2023
459ed0b
[Feature] NLP trace support and example
shee-gao Aug 31, 2023
a6b2c12
hide STOCHASTIC
tjthereal Sep 5, 2023
505b832
Add GPTQ
yz-zhang-sg Aug 4, 2023
96caec1
Update gptq.py
yz-zhang-sg Aug 8, 2023
de5a1c1
Add bert-base-uncased-mrpc GPTQ version
yz-zhang-sg Aug 8, 2023
892331d
Fixed two devices problem
yz-zhang-sg Aug 9, 2023
73fc14e
Update gptq.py
yz-zhang-sg Aug 9, 2023
02c248f
gptq use academic backend
yz-zhang-sg Aug 9, 2023
dc22db9
fix model insert & delete info | fix deploy
yz-zhang-sg Aug 14, 2023
1530d26
nlp deploy
shee-gao Sep 12, 2023
5c6e623
1.Correcting the method for registering operators in MQBench.
zhengjin-xu11 Sep 12, 2023
38a0c54
Complete merge
shee-gao Sep 12, 2023
195b522
conflict fixed
zhengjin-xu11 Sep 12, 2023
d489ab7
add fast test before push
zhengjin-xu11 Sep 14, 2023
4290bd7
FP8 ONNX problem
tjthereal Sep 14, 2023
1 change: 1 addition & 0 deletions FP8_Emulator/cmodel/__init__.py
@@ -0,0 +1 @@
from . import simple
219 changes: 219 additions & 0 deletions FP8_Emulator/cmodel/simple.py
@@ -0,0 +1,219 @@
#------------------------------------------------------------------------------
# Copyright (c) 2023, Intel Corporation - All rights reserved.
# This file is part of FP8-Emulation-Toolkit
#
# SPDX-License-Identifier: BSD-3-Clause
#------------------------------------------------------------------------------
# Naveen Mellempudi (Intel Corporation)
#------------------------------------------------------------------------------

import sys
import warnings

import torch
from torch import nn
from torch.autograd import Function

# compiled C-model extensions (built from the sources under cmodel/simple/)
import simple_gemm_dev
import simple_conv2d_dev

# backup the original torch functions
fallback_addmm = torch.addmm
fallback_matmul = torch.matmul
fallback_mm = torch.mm

def is_transposed(input):
    # Return a contiguous view of `input` plus a flag indicating whether it
    # had to be transposed to become contiguous.
    if input.is_contiguous():
        return input, False
    elif input.t().is_contiguous():
        return input.t(), True
    else:
        return input.contiguous(), False

def addmm(input, mat1, mat2, beta=1.0, alpha=1.0, out=None):
    if input.dtype == torch.float32 and mat1.dtype == torch.float32 and \
            mat1.dim() == 2 and mat2.dim() == 2 and mat1.size(1) == mat2.size(0):
        if out is not None:
            output = out
        else:
            output = torch.zeros([mat1.size(0), mat2.size(1)])
        a_mat, a_trans = is_transposed(mat1)
        b_mat, b_trans = is_transposed(mat2)
        output = SimpleAddmm.apply(output, input, a_mat, b_mat, alpha, beta, a_trans, b_trans)
        ret = output
    else:
        warnings.warn('simple.addmm does not support the input dimensions - input :{}, mat1: {}, mat2: {}, falling back to torch.addmm'.format(
            input.size(), mat1.size(), mat2.size()))
        ret = fallback_addmm(input, mat1, mat2, beta=beta, alpha=alpha, out=out)
    return ret

def matmul(input, other, out=None):
    if input.dtype == torch.float32 and other.dtype == torch.float32 and \
            input.dim() == 2 and other.dim() == 2 and input.size(1) == other.size(0):
        if out is not None:
            output = out
        else:
            output = torch.zeros([input.size(0), other.size(1)])
        a_mat, a_trans = is_transposed(input)
        b_mat, b_trans = is_transposed(other)
        output = SimpleMatmul.apply(output, a_mat, b_mat, 1.0, a_trans, b_trans)
        return output
    # Batch MatMul implementation
    elif input.dtype == torch.float32 and other.dtype == torch.float32 and \
            input.dim() == 3 and other.dim() == 2 and input.size(2) == other.size(0):
        if out is not None:
            output = out
        else:
            output = torch.zeros([input.size(0), input.size(1), other.size(1)])
        a_mat, a_trans = is_transposed(input)
        b_mat, b_trans = is_transposed(other)
        output = torch.stack(tuple([SimpleMatmul.apply(out1, a_mat1, b_mat, 1.0, a_trans, b_trans) \
                                    for a_mat1, out1 in zip(a_mat, output)]))
        return output
    else:
        warnings.warn('simple.matmul does not support the input dimensions - input :{}, other: {}, falling back to torch.matmul'.format(
            input.size(), other.size()))
        return fallback_matmul(input, other, out=out)

def mm(input, mat2, out=None):
    if input.dtype == torch.float32 and mat2.dtype == torch.float32 and \
            input.dim() == 2 and mat2.dim() == 2 and input.size(1) == mat2.size(0):
        if out is not None:
            output = out
        else:
            output = torch.zeros([input.size(0), mat2.size(1)])
        a_mat, a_trans = is_transposed(input)
        b_mat, b_trans = is_transposed(mat2)
        output = SimpleMatmul.apply(output, a_mat, b_mat, 1.0, a_trans, b_trans)
        return output
    else:
        warnings.warn('simple.mm does not support the input dimensions - input :{}, mat2: {}, falling back to torch.mm'.format(
            input.size(), mat2.size()))
        return fallback_mm(input, mat2, out=out)

def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
    N = input.size()[0]
    C = input.size()[1]
    H = input.size()[2]
    W = input.size()[3]
    K = weight.size()[0]
    C1 = weight.size()[1]
    R = weight.size()[2]
    S = weight.size()[3]

    if dilation[0] > 1:
        sys.exit("ERROR: simple_conv2d does not support dilated convolutions.")
    if padding[0] != padding[1]:
        sys.exit("ERROR: simple_conv2d does not support non-uniform padding; pad_h must be equal to pad_w.")
    if groups > 1:
        sys.exit("ERROR: simple_conv2d does not support grouped convolutions; set groups to 1.")

    H_out = ((H + (2*padding[0]) - dilation[0] * (R-1) - 1)/stride[0]) + 1
    W_out = ((W + (2*padding[1]) - dilation[1] * (S-1) - 1)/stride[1]) + 1
    output = torch.empty([N, K, int(H_out), int(W_out)])
    output = SimpleConv2dFunction.apply(output, input, weight, bias, stride, padding, dilation, groups)
    return output

class SimpleAddmm(Function):
    @staticmethod
    def forward(ctx, output, input, mat1, mat2, alpha, beta, a_trans, b_trans):
        ctx.save_for_backward(mat1, mat2)
        ctx.a_trans = a_trans
        ctx.b_trans = b_trans
        ctx.alpha = alpha

        simple_gemm_dev.gemm(output, mat1, mat2, alpha, a_trans, b_trans)
        output += beta * input
        ctx.mark_dirty(output)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        mat1, mat2 = ctx.saved_tensors

        alpha = ctx.alpha
        a_trans = ctx.a_trans
        b_trans = ctx.b_trans

        grad_mat1 = torch.zeros_like(mat1)
        grad_mat2 = torch.zeros_like(mat2)
        grad_out, out_trans = is_transposed(grad_output)

        if a_trans:
            simple_gemm_dev.gemm(grad_mat1, mat2, grad_out, alpha, b_trans, not out_trans)
        else:
            simple_gemm_dev.gemm(grad_mat1, grad_out, mat2, alpha, out_trans, not b_trans)

        if b_trans:
            simple_gemm_dev.gemm(grad_mat2, grad_out, mat1, alpha, not out_trans, a_trans)
        else:
            simple_gemm_dev.gemm(grad_mat2, mat1, grad_out, alpha, not a_trans, out_trans)

        return (grad_output, grad_output, grad_mat1, grad_mat2, None, None, None, None)

class SimpleMatmul(Function):
    @staticmethod
    def forward(ctx, output, mat1, mat2, alpha, a_trans, b_trans):
        ctx.save_for_backward(mat1, mat2)
        ctx.a_trans = a_trans
        ctx.b_trans = b_trans
        ctx.alpha = alpha

        simple_gemm_dev.gemm(output, mat1, mat2, alpha, a_trans, b_trans)
        ctx.mark_dirty(output)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        mat1, mat2 = ctx.saved_tensors
        alpha = ctx.alpha
        a_trans = ctx.a_trans
        b_trans = ctx.b_trans

        grad_mat1 = torch.empty_like(mat1)
        grad_mat2 = torch.empty_like(mat2)
        grad_out, out_trans = is_transposed(grad_output)

        if a_trans:
            simple_gemm_dev.gemm(grad_mat1, mat2, grad_out, alpha, b_trans, not out_trans)
        else:
            simple_gemm_dev.gemm(grad_mat1, grad_out, mat2, alpha, out_trans, not b_trans)

        if b_trans:
            simple_gemm_dev.gemm(grad_mat2, grad_out, mat1, alpha, not out_trans, a_trans)
        else:
            simple_gemm_dev.gemm(grad_mat2, mat1, grad_out, alpha, not a_trans, out_trans)
        return (grad_output, grad_mat1, grad_mat2, None, None, None)


class SimpleConv2dFunction(Function):
    @staticmethod
    def forward(ctx, output, inputs, weights, bias, stride, padding, dilation, groups):
        #print("### conv2d fwd called input size: ", inputs.size(), weights.size(), stride, padding, dilation, groups)
        ctx.save_for_backward(inputs, weights)  #, bias)
        ctx.stride = stride
        ctx.padding = padding
        ctx.dilation = dilation
        ctx.groups = groups

        if bias is None:
            bias_fw = torch.zeros(output.size()[1])
        else:
            bias_fw = bias

        simple_conv2d_dev.conv2d_fp(output, inputs, weights, bias_fw, stride[0], padding[0], dilation[0], groups)
        ctx.mark_dirty(output)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        #inputs, weights, bias = ctx.saved_tensors
        inputs, weights = ctx.saved_tensors
        stride = ctx.stride
        padding = ctx.padding
        dilation = ctx.dilation
        groups = ctx.groups
        #print("### conv2d bwd called input size: ", inputs.size(), weights.size(), stride, padding, dilation, groups)
        grad_inp = torch.zeros_like(inputs)
        grad_wts = torch.zeros_like(weights)

        simple_conv2d_dev.conv2d_bp(grad_inp, grad_output, weights, stride[0], padding[0], dilation[0], groups)
        simple_conv2d_dev.conv2d_wu(grad_wts, grad_output, inputs, stride[0], padding[0], dilation[0], groups)
        return (grad_output, grad_inp, grad_wts, None, None, None, None, None)
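
Note: simple.py only defines the wrappers and backs up the stock functions as fallback_addmm / fallback_matmul / fallback_mm; where MQBench actually swaps them in is not shown in this diff. The following is a minimal usage sketch, assuming FP8_Emulator is importable and the compiled simple_gemm_dev / simple_conv2d_dev extensions have been built; the patch_torch / unpatch_torch helpers are hypothetical and not part of this PR.

# Hypothetical usage sketch -- not part of this PR.
import torch
from FP8_Emulator.cmodel import simple  # requires the compiled extensions

def patch_torch():
    # Route the float32 2-D paths through the C-model kernels.
    torch.addmm = simple.addmm
    torch.matmul = simple.matmul
    torch.mm = simple.mm

def unpatch_torch():
    # Restore the originals that simple.py backed up at import time.
    torch.addmm = simple.fallback_addmm
    torch.matmul = simple.fallback_matmul
    torch.mm = simple.fallback_mm

patch_torch()
x = torch.randn(4, 8)
w = torch.randn(8, 16)
y = torch.matmul(x, w)  # dispatched to SimpleMatmul / simple_gemm_dev.gemm
unpatch_torch()

Unsupported shapes or dtypes fall through to the backed-up torch functions via the warnings.warn branches, so patching is safe for models that mix supported and unsupported call sites.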
127 changes: 127 additions & 0 deletions FP8_Emulator/cmodel/simple/simple_conv2d.cpp
@@ -0,0 +1,127 @@
/*----------------------------------------------------------------------------*
* Copyright (c) 2023, Intel Corporation - All rights reserved.
* This file is part of FP8-Emulation-Toolkit
*
* SPDX-License-Identifier: BSD-3-Clause
*----------------------------------------------------------------------------*
* Naveen Mellempudi (Intel Corporation)
*----------------------------------------------------------------------------*/

#include <torch/extension.h>
#include <vector>
#include <iostream>
#include <time.h>
#include <immintrin.h>
#include <sys/syscall.h>
#include <omp.h>

extern int simple_conv2d_impl_fp(float* outputs, float *inputs, float *weights, float* bias, int N, int C, int iH, int iW,
                                 int K, int R, int S, int stride, int padding, int dilation, int groups);
extern int simple_conv2d_impl_bp(float* inputs, float *outputs, float *weights, int N, int C, int iH, int iW,
                                 int K, int R, int S, int stride, int padding, int dilation, int groups);
extern int simple_conv2d_impl_wu(float *weights, float *outputs, float *inputs, int N, int C, int iH, int iW,
                                 int K, int R, int S, int stride, int padding, int dilation, int groups);

#define gettid() ((int)syscall(SYS_gettid))

using namespace torch::autograd::profiler;


double get_time() {
    static bool init_done = false;
    static struct timespec stp = {0,0};
    struct timespec tp;
    clock_gettime(CLOCK_REALTIME, &tp);

    if(!init_done) {
        init_done = true;
        stp = tp;
    }
    double ret = (tp.tv_sec - stp.tv_sec) * 1e3 + (tp.tv_nsec - stp.tv_nsec)*1e-6;
    return ret;
}

at::Tensor simple_conv2d_fp(torch::Tensor& output, torch::Tensor input, torch::Tensor weight, torch::Tensor bias,
                            int stride, int padding, int dilation, int groups)
{
    RECORD_FUNCTION("simple_conv2d_fp", std::vector<c10::IValue>({input, weight, bias}));

    auto N = input.size(0);
    auto C = input.size(1);
    auto H = input.size(2);
    auto W = input.size(3);

    auto K = weight.size(0);
    //auto C1 = weight.size(1);
    auto R = weight.size(2);
    auto S = weight.size(3);

    float *input_ptr = input.data_ptr<float>();
    float *weight_ptr = weight.data_ptr<float>();
    float *output_ptr = output.data_ptr<float>();
    float *bias_ptr = bias.data_ptr<float>();

    simple_conv2d_impl_fp(output_ptr, input_ptr, weight_ptr, bias_ptr, N, C, H, W,
                          K, R, S, stride, padding, dilation, groups);

    //thnn_conv2d_out(output, input, weight,
    return output;
}

at::Tensor simple_conv2d_bp(torch::Tensor& input, torch::Tensor output, torch::Tensor weight,
                            int stride, int padding, int dilation, int groups)
{
    RECORD_FUNCTION("simple_conv2d_bp", std::vector<c10::IValue>({output, weight}));

    auto N = input.size(0);
    auto C = input.size(1);
    auto H = input.size(2);
    auto W = input.size(3);

    auto K = weight.size(0);
    //auto C1 = weight.size(1);
    auto R = weight.size(2);
    auto S = weight.size(3);

    float *input_ptr = input.data_ptr<float>();
    float *weight_ptr = weight.data_ptr<float>();
    float *output_ptr = output.data_ptr<float>();

    simple_conv2d_impl_bp(input_ptr, output_ptr, weight_ptr, N, C, H, W,
                          K, R, S, stride, padding, dilation, groups);

    //thnn_conv2d_out(output, input, weight,
    return input;
}

at::Tensor simple_conv2d_wu(torch::Tensor& weight, torch::Tensor output, torch::Tensor input,
                            int stride, int padding, int dilation, int groups)
{
    RECORD_FUNCTION("simple_conv2d_wu", std::vector<c10::IValue>({output, input}));

    auto N = input.size(0);
    auto C = input.size(1);
    auto H = input.size(2);
    auto W = input.size(3);

    auto K = weight.size(0);
    //auto C1 = weight.size(1);
    auto R = weight.size(2);
    auto S = weight.size(3);

    float *input_ptr = input.data_ptr<float>();
    float *weight_ptr = weight.data_ptr<float>();
    float *output_ptr = output.data_ptr<float>();

    simple_conv2d_impl_wu(weight_ptr, output_ptr, input_ptr, N, C, H, W,
                          K, R, S, stride, padding, dilation, groups);

    //thnn_conv2d_out(output, input, weight,
    return weight;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("conv2d_fp", &simple_conv2d_fp, "simple conv_fp implementation");
    m.def("conv2d_bp", &simple_conv2d_bp, "simple conv_bp implementation");
    m.def("conv2d_wu", &simple_conv2d_wu, "simple conv_wu implementation");
}
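
The binding above uses TORCH_EXTENSION_NAME and RECORD_FUNCTION, i.e. it is meant to be compiled as a PyTorch C++ extension, but the build script is not part of this view of the diff. Below is a minimal, hypothetical setup.py sketch using torch.utils.cpp_extension; the second source file and the compile flags are assumptions, and some translation unit must provide simple_conv2d_impl_fp/_bp/_wu.

# Hypothetical build sketch -- not part of this PR; file names and flags are assumptions.
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CppExtension

setup(
    name="simple_conv2d_dev",
    ext_modules=[
        CppExtension(
            name="simple_conv2d_dev",
            sources=[
                "FP8_Emulator/cmodel/simple/simple_conv2d.cpp",
                "FP8_Emulator/cmodel/simple/simple_conv2d_impl.cpp",  # assumed: defines simple_conv2d_impl_*
            ],
            extra_compile_args=["-fopenmp", "-mavx2"],  # omp.h / immintrin.h are included above
            extra_link_args=["-fopenmp"],
        ),
    ],
    cmdclass={"build_ext": BuildExtension},
)

# Build in place with:  python setup.py build_ext --inplace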