Since the shape of weights is known in advance, the MatMul weights can be created with format tag dnnl::memory::format_tag::any to enable the library to choose the most appropriate layout for best performance.
#include <cassert>
#include <cctype>
#include <cmath>
#include <cstdio>
#include <iostream>
#include <random>
#include <stdexcept>
#include <vector>
#include "example_utils.hpp"
namespace {

/// Overwrites every element of `v` with a pseudo-random float.
///
/// Values are drawn from a uniform real distribution constructed with the
/// bounds (0, 1). The Mersenne Twister engine is default-constructed on each
/// call, so repeated calls reproduce the same sequence of values.
///
/// @param v Vector to fill in place; its size is left unchanged.
void init_vector(std::vector<float> &v) {
    std::mt19937 engine;
    std::uniform_real_distribution<float> dist(0, 1);
    for (auto it = v.begin(); it != v.end(); ++it) {
        *it = dist(engine);
    }
}

/// Overwrites every element of `v` with a pseudo-random byte.
///
/// Values are drawn as unsigned ints from a uniform integer distribution over
/// [0, 255] and then narrowed to uint8_t. The Mersenne Twister engine is
/// default-constructed on each call, so repeated calls reproduce the same
/// sequence of values.
///
/// @param v Vector to fill in place; its size is left unchanged.
void init_vector(std::vector<uint8_t> &v) {
    std::mt19937 engine;
    std::uniform_int_distribution<unsigned int> dist(0, 255);
    for (auto it = v.begin(); it != v.end(); ++it) {
        const unsigned int sample = dist(engine);
        *it = static_cast<uint8_t>(sample);
    }
}

} // namespace
// Upper bound of the `run` loop in the inference routine further down in this
// file; initialized to 1, i.e. a single pass.
int number_of_runs = 1;
int64_t K, int64_t N,
const engine &eng) {
attr.set_post_ops(po);
}
std::vector<uint8_t> A_u8(M * K);
init_vector(A_u8);
std::vector<float> scales_f32(N);
init_vector(scales_f32);
int32_t zp_A = 128, zp_C = 40;
write_to_dnnl_memory(A_u8.data(), A_u8_mem);
write_to_dnnl_memory(&zp_A, zp_A_mem);
write_to_dnnl_memory(&zp_C, zp_C_mem);
write_to_dnnl_memory(scales_f32.data(), scale_f32_mem);
}
int32_t zp_C = 0;
std::vector<uint8_t> C_u8(M * N);
read_from_dnnl_memory(C_u8.data(), C_u8_mem);
read_from_dnnl_memory(&zp_C, zp_C_mem);
for (int64_t i = 0; i < M * N; ++i)
if (C_u8[i] < zp_C)
throw std::logic_error(
"Smoke check failed."
"\n\tQuantized value is smaller than the zero point,"
"\n\twhich should not happen since ReLU was applied.");
}
void infer(
const matmul &matmul_p, int64_t M, int64_t N, int64_t K,
prepare_input(A_u8_mem, scale_f32_mem, zp_A_mem, zp_C_mem);
for (int run = 0; run < number_of_runs; ++run)
{{DNNL_ARG_SRC, A_u8_mem}, {DNNL_ARG_WEIGHTS, B_s8_mem},
{DNNL_ARG_DST, C_u8_mem},
{DNNL_ARG_ATTR_OUTPUT_SCALES, scale_f32_mem},
{DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC, zp_A_mem},
{DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_DST, zp_C_mem}});
s.wait();
sanity_check(C_u8_mem, zp_C_mem);
}
const int64_t K = 96;
const int64_t N = 1000;
auto matmul_pd = matmul_pd_create(K, N, eng);
std::vector<float> B_f32(K * N);
init_vector(B_f32);
memory B_s8_mem(matmul_pd.weights_desc(), eng);
{
write_to_dnnl_memory(B_f32.data(), B_f32_mem);
s.wait();
}
for (int64_t M : {1, 100})
infer(matmul_p, M, N, K, B_s8_mem, eng);
}
// Program entry point: passes inference_int8_matmul and engine_kind to
// handle_example_errors and returns that call's result as the exit status.
// NOTE(review): engine_kind is not declared anywhere in the visible listing
// and argc/argv are unused here — presumably a line deriving engine_kind from
// the command-line arguments was lost; confirm against the full source.
int main(int argc, char **argv) {
return handle_example_errors(inference_int8_matmul, engine_kind);
}
@ eltwise_relu
Elementwise: rectified linear unit (ReLU)
Definition dnnl.hpp:490
#define DNNL_RUNTIME_S32_VAL
A wildcard value for int32_t values that are unknown at a primitive creation time.
Definition dnnl_types.h:941
#define DNNL_RUNTIME_DIM_VAL
A wildcard value for dimensions that are unknown at a primitive creation time.
Definition dnnl_types.h:916
#define DNNL_RUNTIME_F32_VAL
A wildcard value for floating point values that are unknown at a primitive creation time.
Definition dnnl_types.h:933
#define DNNL_ARG_DST
A special mnemonic for destination argument for primitives that have a single destination.
Definition dnnl_types.h:1806
#define DNNL_ARG_SRC
A special mnemonic for source argument for primitives that have a single source.
Definition dnnl_types.h:1782
@ matmul_d
matmul descriptor
Definition dnnl.hpp:788
oneDNN namespace
Definition dnnl.hpp:81
An execution engine.
Definition dnnl.hpp:844
kind
Kinds of engines.
Definition dnnl.hpp:849
Descriptor for a matmul primitive.
Definition dnnl.hpp:9994
Primitive descriptor for a matmul primitive.
Definition dnnl.hpp:10041
Matrix multiplication (matmul) primitive.
Definition dnnl.hpp:9992
A memory descriptor.
Definition dnnl.hpp:1729
memory::dims dims() const
Returns dimensions of the memory descriptor.
Definition dnnl.hpp:1930
Memory object.
Definition dnnl.hpp:1188
@ any
Placeholder memory format tag.
Definition dnnl.hpp:1287
@ ab
plain 2D tensor
Definition dnnl.hpp:1293
@ u8
8-bit unsigned integer.
Definition dnnl.hpp:1222
@ s8
8-bit signed integer.
Definition dnnl.hpp:1220
@ f32
32-bit/single-precision floating point.
Definition dnnl.hpp:1216
@ s32
32-bit signed integer.
Definition dnnl.hpp:1218
desc get_desc() const
Returns the associated memory descriptor.
Definition dnnl.hpp:2010
Post-ops.
Definition dnnl.hpp:2205
void append_eltwise(float scale, algorithm algorithm, float alpha, float beta)
Appends an elementwise post-op.
Definition dnnl.hpp:2280
Primitive attributes.
Definition dnnl.hpp:2481
void set_output_scales(int mask, const std::vector< float > &scales)
Sets output scaling factors correspondence mask and values.
Definition dnnl.hpp:2583
void execute(const stream &stream, const std::unordered_map< int, memory > &args) const
Executes computations specified by the primitive in a specified stream.
Reorder primitive.
Definition dnnl.hpp:3118
void execute(const stream &stream, memory &src, memory &dst) const
Executes the reorder primitive.
Definition dnnl.hpp:3227
An execution stream.
Definition dnnl.hpp:1047