Users are free to choose between these two options, as well as any intermediate ones (e.g., specifying some of the parameters at creation time while leaving the others until execution time). This enables balancing between flexibility and performance.
#include <cassert>
#include <cctype>
#include <cmath>
#include <cstdio>
#include <iostream>
#include <random>
#include <stdexcept>
#include <vector>
#include "example_utils.hpp"
namespace {
// Overwrites every element of `v` with a pseudo-random value drawn from a
// uniform distribution constructed with bounds (-1, 1).
//
// The generator is default-constructed on every call, so each call starts
// from the same engine state.
void init_vector(std::vector<float> &v) {
    std::mt19937 engine;
    std::uniform_real_distribution<float> dist(-1.f, 1.f);
    for (auto it = v.begin(); it != v.end(); ++it) {
        *it = dist(engine);
    }
}
// Compares `v2` against the reference `v1` using a relative L2-norm criterion
// and prints a short report to stdout.
//
// The check passes when both vectors have the same length and
//     ||v1 - v2||_2 <= eps_f32 * log(max(2, K)) * ||v1||_2.
//
// Parameters:
//   v1      - reference values.
//   v2      - values under test.
//   K       - scales the tolerance through log(max(2, K)).
//   message - caption printed before the report.
//
// Returns 0 when the check passes and 1 otherwise.
int compare_vectors(const std::vector<float> &v1, const std::vector<float> &v2,
        int64_t K, const char *message) {
    // Vectors of different lengths can never match. Bail out before the loop
    // below would index v2 past its end (or silently ignore its tail).
    if (v1.size() != v2.size()) {
        printf("%s\n\tSize mismatch: %zu vs %zu\n\tAccuracy check: FAILED\n",
                message, v1.size(), v2.size());
        return 1;
    }

    double v1_l2 = 0, diff_l2 = 0;
    for (size_t n = 0; n < v1.size(); ++n) {
        // Promote before multiplying so the squares are formed in double
        // precision, matching the precision of the accumulators.
        const double ref = v1[n];
        const double diff = ref - static_cast<double>(v2[n]);
        v1_l2 += ref * ref;
        diff_l2 += diff * diff;
    }
    v1_l2 = std::sqrt(v1_l2);
    diff_l2 = std::sqrt(diff_l2);

    const double threshold = std::numeric_limits<float>::epsilon()
            * std::log(std::max(2., static_cast<double>(K)));
    const bool ok = diff_l2 <= threshold * v1_l2;

    // Avoid 0/0 in the report when the reference norm is zero (e.g. empty or
    // all-zero input); the pass/fail decision above does not divide.
    double relative_error = 0.;
    if (v1_l2 > 0.)
        relative_error = diff_l2 / v1_l2;
    else if (diff_l2 > 0.)
        relative_error = std::numeric_limits<double>::infinity();

    printf("%s\n\tL2 Norms"
           "\n\t\tReference matrix:%g\n\t\tError:%g\n\t\tRelative_error:%g\n"
           "\tAccuracy check: %s\n",
            message, v1_l2, diff_l2, relative_error, ok ? "OK" : "FAILED");
    return ok ? 0 : 1;
}
}
// How many times sgemm_and_matmul_with_params() invokes each of the three
// implementations (dnnl_sgemm, dynamic MatMul, static MatMul).
int number_of_runs = 1;
// The only beta value accepted by this example: dynamic_matmul_execute() and
// sgemm_and_matmul_with_params() throw std::logic_error when the beta they
// receive differs from it. dynamic_matmul_create() also reads it.
float fixed_beta = 0.f;
// Factory for the MatMul primitive that sgemm_and_matmul_with_params() creates
// once and then passes to dynamic_matmul_execute().
//
// NOTE(review): the body looks truncated in this view. It only copies
// fixed_beta into a local and tests it against zero; the branch is empty and
// there is no return statement even though the declared return type is
// `matmul`. Restore the missing statements before building — flowing off the
// end of a value-returning function is undefined behavior.
matmul dynamic_matmul_create() {
// beta is taken from the file-level fixed_beta rather than from a parameter.
float beta = fixed_beta;
// Non-zero beta branch; empty in this view (presumably attribute/post-op
// setup belongs here — TODO confirm against the upstream example).
if (beta != 0.f) {
}
}
// Takes a MatMul primitive together with SGEMM-style arguments (transpose
// flags, M/N/K, alpha/beta, matrix pointers and leading dimensions).
//
// Throws std::logic_error when `beta` differs from the file-level fixed_beta.
//
// NOTE(review): the body looks truncated in this view. `matmul_p`, M, N, K,
// alpha, A, B, C and ldc are never used, the computed strides are never
// consumed, and `s` is not declared anywhere visible. Presumably the memory
// setup and the primitive execute call are missing — TODO confirm against the
// upstream example.
void dynamic_matmul_execute(
matmul &matmul_p,
char transA,
char transB,
int64_t M, int64_t N, int64_t K, float alpha, const float *A,
int64_t lda, const float *B, int64_t ldb, float beta, float *C,
int64_t ldc) {
// Only the creation-time beta is accepted.
if (beta != fixed_beta)
throw std::logic_error("Run-time beta is not yet supported.");
// A flag of 'N'/'n' selects strides {ld, 1}; any other flag selects {1, ld}.
dims a_strides = tolower(transA) == 'n' ? dims {lda, 1} : dims {1, lda};
dims b_strides = tolower(transB) == 'n' ? dims {ldb, 1} : dims {1, ldb};
// `s` is not declared in this view (presumably a stream object — TODO confirm).
s.wait();
}
// Takes the full set of SGEMM-style arguments in a single call (transpose
// flags, M/N/K, alpha/beta, matrix pointers and leading dimensions).
//
// NOTE(review): the body looks truncated in this view. `attr`, `po`, `a_md`,
// `b_md`, `c_md`, `eng` and `s` are used but not declared anywhere visible,
// the computed strides are never consumed, and no primitive is created or
// executed. Presumably those statements are missing — TODO confirm against the
// upstream example.
void static_matmul_create_and_execute(char transA, char transB, int64_t M,
int64_t N, int64_t K, float alpha, const float *A, int64_t lda,
const float *B, int64_t ldb, float beta, float *C, int64_t ldc) {
// A flag of 'N'/'n' selects strides {ld, 1}; any other flag selects {1, ld}.
dims a_strides = tolower(transA) == 'n' ? dims {lda, 1} : dims {1, lda};
dims b_strides = tolower(transB) == 'n' ? dims {ldb, 1} : dims {1, ldb};
// Post-ops are attached only for a non-zero beta.
if (beta != 0.f) {
attr.set_post_ops(po);
}
// Memory objects are built directly over the caller's buffers. The C-style
// cast to void* drops the const qualifier from A and B.
memory A_m(a_md, eng, (
void *)A);
memory B_m(b_md, eng, (
void *)B);
memory C_m(c_md, eng, (
void *)C);
// `s` is not declared in this view (presumably a stream object — TODO confirm).
s.wait();
}
void sgemm_and_matmul_with_params(char transA, char transB, int64_t M,
int64_t N, int64_t K, float alpha, float beta) {
if (beta != fixed_beta)
throw std::logic_error("Run-time beta is not yet supported.");
std::vector<float> A(M * K);
init_vector(A);
std::vector<float> B(K * N);
init_vector(B);
std::vector<float> C_sgemm(M * N);
init_vector(C_sgemm);
std::vector<float> C_dynamic_matmul = C_sgemm;
std::vector<float> C_static_matmul = C_sgemm;
int64_t lda = tolower(transA) == 'n' ? K : M;
int64_t ldb = tolower(transB) == 'n' ? N : K;
int64_t ldc = N;
for (int run = 0; run < number_of_runs; ++run)
dnnl_sgemm(transA, transB, M, N, K, alpha, A.data(), lda, B.data(), ldb,
beta, C_sgemm.data(), ldc);
auto dynamic_matmul = dynamic_matmul_create();
for (int run = 0; run < number_of_runs; ++run)
dynamic_matmul_execute(dynamic_matmul, transA, transB, M, N, K, alpha,
A.data(), lda, B.data(), ldb, beta, C_dynamic_matmul.data(),
ldc);
for (int run = 0; run < number_of_runs; ++run)
static_matmul_create_and_execute(transA, transB, M, N, K, alpha,
A.data(), lda, B.data(), ldb, beta, C_static_matmul.data(),
ldc);
int rc = 0;
rc |= compare_vectors(
C_sgemm, C_dynamic_matmul, K, "Compare SGEMM vs dynamic MatMul");
if (rc) throw std::logic_error("The resulting matrices diverged too much.");
rc |= compare_vectors(
C_sgemm, C_static_matmul, K, "Compare SGEMM vs static MatMul");
if (rc) throw std::logic_error("The resulting matrices diverged too much.");
}
// Runs the SGEMM/MatMul comparison for one fixed problem:
// transA = 'N', transB = 'T', M = 10, N = 20, K = 30, alpha = 1.1f and
// beta = fixed_beta.
void sgemm_and_matmul() {
    const char transA = 'N';
    const char transB = 'T';
    const int64_t M = 10;
    const int64_t N = 20;
    const int64_t K = 30;
    const float alpha = 1.1f;
    sgemm_and_matmul_with_params(transA, transB, M, N, K, alpha, fixed_beta);
}
// Program entry point.
//
// NOTE(review): the body is empty in this view, so sgemm_and_matmul() is never
// invoked and argc/argv are unused. Presumably the call that runs the example
// was lost — TODO confirm against the upstream example.
int main(int argc, char **argv) {
}
dnnl_status_t DNNL_API dnnl_sgemm(char transa, char transb, dnnl_dim_t M, dnnl_dim_t N, dnnl_dim_t K, float alpha, const float *A, dnnl_dim_t lda, const float *B, dnnl_dim_t ldb, float beta, float *C, dnnl_dim_t ldc)
Performs single-precision matrix-matrix multiply.
#define DNNL_RUNTIME_DIM_VAL
A wildcard value for dimensions that are unknown at primitive creation time.
Definition dnnl_types.h:916
#define DNNL_RUNTIME_F32_VAL
A wildcard value for floating point values that are unknown at primitive creation time.
Definition dnnl_types.h:933
#define DNNL_ARG_ATTR_OUTPUT_SCALES
Output scaling factors provided at execution time.
Definition dnnl_types.h:1945
#define DNNL_ARG_DST
A special mnemonic for destination argument for primitives that have a single destination.
Definition dnnl_types.h:1806
#define DNNL_ARG_SRC
A special mnemonic for source argument for primitives that have a single source.
Definition dnnl_types.h:1782
#define DNNL_ARG_WEIGHTS
A special mnemonic for primitives that have a single weights argument.
Definition dnnl_types.h:1829
@ matmul_d
matmul descriptor
Definition dnnl.hpp:788
oneDNN namespace
Definition dnnl.hpp:81
An execution engine.
Definition dnnl.hpp:844
@ cpu
CPU engine.
Definition dnnl.hpp:853
Descriptor for a matmul primitive.
Definition dnnl.hpp:9994
Primitive descriptor for a matmul primitive.
Definition dnnl.hpp:10041
Matrix multiplication (matmul) primitive.
Definition dnnl.hpp:9992
A memory descriptor.
Definition dnnl.hpp:1729
Memory object.
Definition dnnl.hpp:1188
@ f32
32-bit/single-precision floating point.
Definition dnnl.hpp:1216
std::vector< dim > dims
Vector of dimensions.
Definition dnnl.hpp:1193
Post-ops.
Definition dnnl.hpp:2205
void append_sum(float scale=1.)
Appends an accumulation (sum) post-op.
Definition dnnl.hpp:2251
Primitive attributes.
Definition dnnl.hpp:2481
void set_output_scales(int mask, const std::vector< float > &scales)
Sets output scaling factors correspondence mask and values.
Definition dnnl.hpp:2583
void set_post_ops(const post_ops ops)
Sets post-ops.
Definition dnnl.hpp:2711
void execute(const stream &stream, const std::unordered_map< int, memory > &args) const
Executes computations specified by the primitive in a specified stream.
An execution stream.
Definition dnnl.hpp:1047