// oneAPI Deep Neural Network Library (oneDNN) 1.4.0
// Performance library for Deep Learning
//
// dnnl.hpp
/*******************************************************************************
* Copyright 2016-2020 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

20#ifndef DNNL_HPP
21#define DNNL_HPP
22
#include "dnnl_config.h"

#include <algorithm>
#include <cstdlib>
#include <iterator>
#include <memory>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <vector>

#include "dnnl.h"

#if DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_THREADPOOL
#include "dnnl_threadpool_iface.hpp"
#endif

#if DNNL_GPU_RUNTIME == DNNL_RUNTIME_OCL
#include <CL/cl.h>
#endif
44
// __cpp_exceptions is referred from
// https://gcc.gnu.org/onlinedocs/libstdc++/manual/using_exceptions.html
// gcc < 5 does not define __cpp_exceptions but __EXCEPTIONS,
// Microsoft C++ Compiler does not provide an option to disable exceptions
//
// DNNL_ENABLE_EXCEPTIONS may be predefined by the user to force either mode;
// otherwise it is derived from the compiler's exception-support macros.
#ifndef DNNL_ENABLE_EXCEPTIONS
#if __cpp_exceptions || __EXCEPTIONS \
        || (defined(_MSC_VER) && !defined(__clang__))
#define DNNL_ENABLE_EXCEPTIONS 1
#else
#define DNNL_ENABLE_EXCEPTIONS 0
#endif
#endif

// DNNL_TRAP() stops the program abnormally. It is only used by
// DNNL_THROW_ERROR when exceptions are disabled, so an unrecognised compiler
// falls back to std::abort() instead of rejecting the build with #error
// (which also broke builds that never need the trap).
#if defined(__GNUC__) || defined(__clang__)
#define DNNL_TRAP() __builtin_trap()
#elif defined(__INTEL_COMPILER) || defined(_MSC_VER)
#define DNNL_TRAP() __debugbreak()
#else
#include <cstdlib>
#define DNNL_TRAP() std::abort()
#endif

// DNNL_THROW_ERROR(status, msg) reports a failure: with exceptions enabled it
// throws `error(status, msg)`; otherwise it writes msg to stderr and traps.
#if DNNL_ENABLE_EXCEPTIONS
#define DNNL_THROW_ERROR(status, msg) throw error((status), (msg))
#else
#include <cstdio>
#define DNNL_THROW_ERROR(status, msg) \
    do { \
        std::fputs((msg), stderr); \
        DNNL_TRAP(); \
    } while (0)
#endif
76
79
81namespace dnnl {
82
86
91struct error : public std::exception {
92 dnnl_status_t status;
93 const char *message;
94
99 error(dnnl_status_t status, const char *message)
100 : status(status), message(message) {}
101
103 const char *what() const noexcept override { return message; }
104
110 static void wrap_c_api(dnnl_status_t status, const char *message) {
111 if (status != dnnl_success) DNNL_THROW_ERROR(status, message);
112 }
113};
114
116template <typename T>
117void validate_container_size(const T &v, const char *error_message,
118 int min_size = 1, int max_size = -1) {
119 const int size = (int)v.size();
120 if (size < min_size || (max_size >= 0 && size > max_size))
121 DNNL_THROW_ERROR(dnnl_invalid_arguments, error_message);
122}
124
126template <typename T>
128
142template <typename T, typename traits = handle_traits<T>>
143struct handle {
144private:
145 static dnnl_status_t dummy_destructor(T) { return dnnl_success; }
146 std::shared_ptr<typename std::remove_pointer<T>::type> data_ {0};
147
148protected:
149 bool operator==(const T other) const { return other == data_.get(); }
150 bool operator!=(const T other) const { return !(*this == other); }
151
152public:
160 handle() = default;
161
163 handle(const handle<T, traits> &) = default;
170
176 explicit handle(T t, bool weak = false) { reset(t, weak); }
177
183 void reset(T t, bool weak = false) {
184 data_.reset(t, weak ? &dummy_destructor : traits::destructor);
185 }
186
192 T get(bool allow_empty = false) const {
193 T result = data_.get();
194 if (allow_empty == false && result == nullptr)
195 DNNL_THROW_ERROR(
196 dnnl_invalid_arguments, "object is not initialized");
197 return result;
198 }
199
204 explicit operator T() const { return get(true); }
205
209 explicit operator bool() const { return get(true) != nullptr; }
210
217 bool operator==(const handle<T, traits> &other) const {
218 return other.data_.get() == data_.get();
219 }
220
227 bool operator!=(const handle &other) const { return !(*this == other); }
228};
229
231template <>
232struct handle_traits<dnnl_memory_t> {
233 static dnnl_status_t destructor(dnnl_memory_t p) {
234 return dnnl_memory_destroy(p);
235 }
236};
237
238template <>
240 static dnnl_status_t destructor(dnnl_primitive_desc_t p) {
242 }
243};
244
245template <>
247 static dnnl_status_t destructor(dnnl_primitive_t p) {
248 return dnnl_primitive_destroy(p);
249 }
250};
251
252template <>
254 static dnnl_status_t destructor(dnnl_primitive_desc_iterator_t p) {
256 }
257};
259
261
// Forward declarations of C++ API types that are referenced before their
// full definitions appear.
struct error;
struct memory;
struct primitive_desc;
struct stream;
266
271
275
364
370 return static_cast<dnnl_primitive_kind_t>(kind);
371}
372
376 "could not get a primitive descriptor from a primitive");
377 return pd;
378}
379
382 // TODO (Roma): the code below is only needed because get_primitive_desc
383 // returns a C type.
386 pd, dnnl_query_primitive_kind, 0, (void *)&kind),
387 "could not get a primitive kind from a primitive descriptor");
388 return static_cast<dnnl::primitive::kind>(kind);
389}
390
392
404
430
436 return static_cast<dnnl_scratchpad_mode_t>(mode);
437}
438
465
471 return static_cast<dnnl_prop_kind_t>(kind);
472}
473
475enum class algorithm {
477 undef = dnnl_alg_kind_undef,
575};
576
583
585
588
619
626
628
631
633enum class rnn_flags : unsigned {
636};
637
642 return static_cast<dnnl_rnn_flags_t>(flags);
643}
644
/// Defines bitwise operators (|, &, ^, ~ and the compound |=, &=, ^=) for a
/// scoped enumeration that is used as a set of flags.
///
/// Values are combined in the enumeration's own underlying type
/// (std::underlying_type) rather than a hard-coded `unsigned`, so the
/// operators stay correct for flag enumerations with any underlying type.
/// The non-assigning operators are constexpr so flag combinations can be
/// formed in constant expressions; the compound forms reuse them.
#define DNNL_DEFINE_BITMASK_OPS(enum_name) \
    inline constexpr enum_name operator|(enum_name lhs, enum_name rhs) { \
        return static_cast<enum_name>( \
                static_cast<std::underlying_type<enum_name>::type>(lhs) \
                | static_cast<std::underlying_type<enum_name>::type>(rhs)); \
    } \
\
    inline constexpr enum_name operator&(enum_name lhs, enum_name rhs) { \
        return static_cast<enum_name>( \
                static_cast<std::underlying_type<enum_name>::type>(lhs) \
                & static_cast<std::underlying_type<enum_name>::type>(rhs)); \
    } \
\
    inline constexpr enum_name operator^(enum_name lhs, enum_name rhs) { \
        return static_cast<enum_name>( \
                static_cast<std::underlying_type<enum_name>::type>(lhs) \
                ^ static_cast<std::underlying_type<enum_name>::type>(rhs)); \
    } \
\
    inline enum_name &operator|=(enum_name &lhs, enum_name rhs) { \
        lhs = lhs | rhs; \
        return lhs; \
    } \
\
    inline enum_name &operator&=(enum_name &lhs, enum_name rhs) { \
        lhs = lhs & rhs; \
        return lhs; \
    } \
\
    inline enum_name &operator^=(enum_name &lhs, enum_name rhs) { \
        lhs = lhs ^ rhs; \
        return lhs; \
    } \
\
    inline constexpr enum_name operator~(enum_name rhs) { \
        return static_cast<enum_name>( \
                static_cast<std::underlying_type<enum_name>::type>( \
                        ~static_cast<std::underlying_type<enum_name>::type>( \
                                rhs))); \
    }
682
683DNNL_DEFINE_BITMASK_OPS(normalization_flags)
684DNNL_DEFINE_BITMASK_OPS(rnn_flags)
685
686
701
706 return static_cast<dnnl_rnn_direction_t>(dir);
707}
708
710
713
720enum class query {
723
728
733
742
747
752
755
758
791
810};
811
816 return static_cast<dnnl_query_t>(query);
817}
818
820
822
833
835template <>
836struct handle_traits<dnnl_engine_t> {
837 static dnnl_status_t destructor(dnnl_engine_t p) {
838 return dnnl_engine_destroy(p);
839 }
840};
842
844struct engine : public handle<dnnl_engine_t> {
845 friend struct primitive;
846 friend struct reorder;
847
857
858 using handle::handle;
859
862 engine() = default;
863
868 static size_t get_count(kind kind) {
869 return dnnl_engine_get_count(convert_to_c(kind));
870 }
871
877 engine(kind kind, size_t index) {
880 dnnl_engine_create(&engine, convert_to_c(kind), index),
881 "could not create an engine");
882 reset(engine);
883 }
884
885#if DNNL_GPU_RUNTIME == DNNL_RUNTIME_OCL
892 engine(kind kind, cl_device_id device, cl_context context) {
895 &engine, convert_to_c(kind), device, context),
896 "could not create an engine");
897 reset(engine);
898 }
899#endif
900
906 dnnl_engine_t c_engine;
910 "could not get an engine from a primitive_desc");
911 reset(c_engine, true);
912 }
913
916 kind get_kind() const {
919 "could not get kind of an engine");
920 return static_cast<engine::kind>(kind);
921 }
922
923#if DNNL_GPU_RUNTIME == DNNL_RUNTIME_OCL
926 cl_context get_ocl_context() const {
927 cl_context context = nullptr;
929 "could not get an OpenCL context fron an engine");
930 return context;
931 }
932
935 cl_device_id get_ocl_device() const {
936 cl_device_id device = nullptr;
938 "could not get an OpenCL device fron an engine");
939 return device;
940 }
941#endif
942
948 template <typename primitive_desc>
949 static engine query(const primitive_desc &pd) {
950 return query(pd, dnnl::query::engine);
951 }
952
953private:
955 return static_cast<dnnl_engine_kind_t>(kind);
956 }
957
958 template <typename primitive_desc>
959 static engine query(const primitive_desc &pd, dnnl::query what) {
960 dnnl_engine_t c_engine;
962 dnnl::convert_to_c(what), 0, &c_engine),
963 "could not get an engine from a primitive_desc");
964 return engine(c_engine, true);
965 }
966};
967
973 return static_cast<dnnl_engine_kind_t>(kind);
974}
975
977
985
987template <>
988struct handle_traits<dnnl_stream_t> {
989 static dnnl_status_t destructor(dnnl_stream_t p) {
990 return dnnl_stream_destroy(p);
991 }
992};
993template <>
995 static dnnl_status_t destructor(dnnl_stream_attr_t p) {
996 return dnnl_stream_attr_destroy(p);
997 }
998};
1000
1002struct stream_attr : public handle<dnnl_stream_attr_t> {
1003 using handle::handle;
1004
1006 stream_attr() = default;
1007
1013 dnnl_stream_attr_t attr;
1015 "could not create stream attributes");
1016 reset(attr);
1017 }
1018
1019#if DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_THREADPOOL
1027 void set_threadpool(threadpool_iface *threadpool) {
1028 error::wrap_c_api(dnnl_stream_attr_set_threadpool(get(), threadpool),
1029 "could not set stream threadpool attribute");
1030 }
1031
1037 threadpool_iface *get_threadpool() {
1038 threadpool_iface *tp;
1039 error::wrap_c_api(dnnl_stream_attr_get_threadpool(get(), (void **)&tp),
1040 "could not set stream threadpool attribute");
1041 return tp;
1042 }
1043#endif
1044};
1045
1047struct stream : public handle<dnnl_stream_t> {
1048 using handle::handle;
1049
1062
1065 stream() = default;
1066
1074 const stream_attr &attr = stream_attr()) {
1077 static_cast<dnnl_stream_flags_t>(flags),
1078 attr.get(true)),
1079 "could not create a stream");
1080 reset(stream);
1081 }
1082
1083#if DNNL_GPU_RUNTIME == DNNL_RUNTIME_OCL
1088 stream(const engine &engine, cl_command_queue queue) {
1091 "could not create a stream");
1092 reset(stream);
1093 }
1094
1097 cl_command_queue get_ocl_command_queue() const {
1098 cl_command_queue queue = nullptr;
1100 "could not get an OpenCL command queue from a stream");
1101 return queue;
1102 }
1103#endif
1104
1109 dnnl_stream_wait(get()), "could not wait on a stream");
1110 return *this;
1111 }
1112};
1113
1114DNNL_DEFINE_BITMASK_OPS(stream::flags)
1115
1116
1117
1118
1181
1188struct memory : public handle<dnnl_memory_t> {
1193 typedef std::vector<dim> dims;
1194
1201 template <typename T>
1202 static void validate_dims(const std::vector<T> &v, int min_size = 0) {
1203 validate_container_size(
1204 v, "dimensions are invalid", min_size, DNNL_MAX_NDIMS);
1205 }
1206
1224
1241
1282 enum class format_tag {
1288
1291
1296
1307
1322
1343
1345 x = a,
1368
1399
1416
1451
1452 // Opaque blocked formats
1453
1454 Abc16a = dnnl_Abc16a,
1455 ABc16a16b = dnnl_ABc16a16b,
1456 ABc4a4b = dnnl_ABc4a4b,
1457 aBc16b = dnnl_aBc16b,
1458 ABc16b16a = dnnl_ABc16b16a,
1459 Abc4a = dnnl_Abc4a,
1460 aBc4b = dnnl_aBc4b,
1461 ABc4b16a4b = dnnl_ABc4b16a4b,
1462 ABc2b8a4b = dnnl_ABc2b8a4b,
1463 ABc4b4a = dnnl_ABc4b4a,
1464 ABc8a16b2a = dnnl_ABc8a16b2a,
1465 ABc8a8b = dnnl_ABc8a8b,
1466 aBc8b = dnnl_aBc8b,
1467 ABc8b16a2b = dnnl_ABc8b16a2b,
1468 ABc8b8a = dnnl_ABc8b8a,
1469 Abcd16a = dnnl_Abcd16a,
1470 ABcd16a16b = dnnl_ABcd16a16b,
1471 aBcd16b = dnnl_aBcd16b,
1472 ABcd16b16a = dnnl_ABcd16b16a,
1473 aBCd16b16c = dnnl_aBCd16b16c,
1474 aBCd16c16b = dnnl_aBCd16c16b,
1475 Abcd4a = dnnl_Abcd4a,
1476 aBcd4b = dnnl_aBcd4b,
1477 ABcd4b16a4b = dnnl_ABcd4b16a4b,
1478 ABcd2b8a4b = dnnl_ABcd2b8a4b,
1479 ABcd4b4a = dnnl_ABcd4b4a,
1480 ABcd4a4b = dnnl_ABcd4a4b,
1481 aBCd4c16b4c = dnnl_aBCd4c16b4c,
1482 aBCd2c8b4c = dnnl_aBCd2c8b4c,
1483 aBCd4c4b = dnnl_aBCd4c4b,
1484 aBCd4b4c = dnnl_aBCd4b4c,
1485 ABcd8a16b2a = dnnl_ABcd8a16b2a,
1486 ABcd8a8b = dnnl_ABcd8a8b,
1489 ABcd8b16a2b = dnnl_ABcd8b16a2b,
1490 aBCd8b16c2b = dnnl_aBCd8b16c2b,
1493 aBCd8b8c = dnnl_aBCd8b8c,
1494 aBCd8c16b2c = dnnl_aBCd8c16b2c,
1495 aBCd8c8b = dnnl_aBCd8c8b,
1496 Abcde16a = dnnl_Abcde16a,
1497 ABcde16a16b = dnnl_ABcde16a16b,
1498 aBcde16b = dnnl_aBcde16b,
1499 ABcde16b16a = dnnl_ABcde16b16a,
1500 aBCde16b16c = dnnl_aBCde16b16c,
1501 aBCde16c16b = dnnl_aBCde16c16b,
1502 aBCde2c8b4c = dnnl_aBCde2c8b4c,
1503 Abcde4a = dnnl_Abcde4a,
1504 aBcde4b = dnnl_aBcde4b,
1505 ABcde4b4a = dnnl_ABcde4b4a,
1506 ABcde4a4b = dnnl_ABcde4a4b,
1507 aBCde4b4c = dnnl_aBCde4b4c,
1508 aBCde4c16b4c = dnnl_aBCde4c16b4c,
1509 aBCde4c4b = dnnl_aBCde4c4b,
1510 Abcde8a = dnnl_Abcde8a,
1511 ABcde8a8b = dnnl_ABcde8a8b,
1512 aBcde8b = dnnl_aBcde8b,
1513 ABcde8b16a2b = dnnl_ABcde8b16a2b,
1514 ABcde4b16a4b = dnnl_ABcde4b16a4b,
1515 ABcde2b8a4b = dnnl_ABcde2b8a4b,
1516 aBCde8b16c2b = dnnl_aBCde8b16c2b,
1517 ABcde8b8a = dnnl_ABcde8b8a,
1518 aBCde8b8c = dnnl_aBCde8b8c,
1519 ABcd4a8b8a4b = dnnl_ABcd4a8b8a4b,
1520 ABcd2a8b8a2b = dnnl_ABcd2a8b8a2b,
1521 aBCde4b8c8b4c = dnnl_aBCde4b8c8b4c,
1522 aBCde2b8c8b2c = dnnl_aBCde2b8c8b2c,
1523 aBCde8c16b2c = dnnl_aBCde8c16b2c,
1524 aBCde8c8b = dnnl_aBCde8c8b,
1525 aBcdef16b = dnnl_aBcdef16b,
1526 aBCdef16b16c = dnnl_aBCdef16b16c,
1527 aBCdef16c16b = dnnl_aBCdef16c16b,
1528 aBcdef4b = dnnl_aBcdef4b,
1529 aBCdef4c4b = dnnl_aBCdef4c4b,
1530 aBCdef4b4c = dnnl_aBCdef4b4c,
1531 aBCdef8b8c = dnnl_aBCdef8b8c,
1532 aBCdef8c16b2c = dnnl_aBCdef8c16b2c,
1533 aBCdef4c16b4c = dnnl_aBCdef4c16b4c,
1534 aBCdef8c8b = dnnl_aBCdef8c8b,
1535 aBdc16b = dnnl_aBdc16b,
1536 aBdc4b = dnnl_aBdc4b,
1537 aBdc8b = dnnl_aBdc8b,
1538 aBdec16b = dnnl_aBdec16b,
1539 aBdec4b = dnnl_aBdec4b,
1540 aBdec8b = dnnl_aBdec8b,
1541 aBdefc16b = dnnl_aBdefc16b,
1542 aCBdef16c16b = dnnl_aCBdef16c16b,
1543 aCBdef16b16c = dnnl_aCBdef16b16c,
1544 aBdefc4b = dnnl_aBdefc4b,
1545 aBdefc8b = dnnl_aBdefc8b,
1546 Acb16a = dnnl_Acb16a,
1547 Acb4a = dnnl_Acb4a,
1548 Acb8a = dnnl_Acb8a,
1549 aCBd16b16c = dnnl_aCBd16b16c,
1550 aCBd16c16b = dnnl_aCBd16c16b,
1551 aCBde16b16c = dnnl_aCBde16b16c,
1552 aCBde16c16b = dnnl_aCBde16c16b,
1553 Acdb16a = dnnl_Acdb16a,
1554 Acdb4a = dnnl_Acdb4a,
1555 Acdb8a = dnnl_Acdb8a,
1556 Acdeb16a = dnnl_Acdeb16a,
1557 Acdeb4a = dnnl_Acdeb4a,
1558 Acdeb8a = dnnl_Acdeb8a,
1559 BAc16a16b = dnnl_BAc16a16b,
1560 BAc16b16a = dnnl_BAc16b16a,
1561 BAcd16a16b = dnnl_BAcd16a16b,
1562 BAcd16b16a = dnnl_BAcd16b16a,
1563 ABcd32a32b = dnnl_ABcd32a32b,
1564 BAcde16b16a = dnnl_BAcde16b16a,
1565 BAcde16a16b = dnnl_BAcde16a16b,
1566 aBdec32b = dnnl_aBdec32b,
1567 Abcdef16a = dnnl_Abcdef16a,
1568 Acdb32a = dnnl_Acdb32a,
1569 aBCd2b4c2b = dnnl_aBCd2b4c2b,
1570 aBCde2b4c2b = dnnl_aBCde2b4c2b,
1571 aBCdef2b4c2b = dnnl_aBCdef2b4c2b,
1572 aBCd2c4b2c = dnnl_aBCd2c4b2c,
1573 aBCde2c4b2c = dnnl_aBCde2c4b2c,
1574 aBCdef2c4b2c = dnnl_aBCdef2c4b2c,
1575 aBCd4b8c2b = dnnl_aBCd4b8c2b,
1576 aBCde4b8c2b = dnnl_aBCde4b8c2b,
1577 aBCdef4b8c2b = dnnl_aBCdef4b8c2b,
1578 aBCd4c8b2c = dnnl_aBCd4c8b2c,
1579 aBCde4c8b2c = dnnl_aBCde4c8b2c,
1580 aBCdef4c8b2c = dnnl_aBCdef4c8b2c,
1581
1582 format_tag_last = dnnl_format_tag_last,
1583
1584 nCdhw16c = dnnl_nCdhw16c,
1585 nCdhw4c = dnnl_nCdhw4c,
1586 nCdhw8c = dnnl_nCdhw8c,
1587 nChw16c = dnnl_nChw16c,
1588 nChw4c = dnnl_nChw4c,
1589 nChw8c = dnnl_nChw8c,
1590 nCw16c = dnnl_nCw16c,
1591 nCw4c = dnnl_nCw4c,
1592 nCw8c = dnnl_nCw8c,
1593 NCw16n16c = dnnl_NCw16n16c,
1594 NChw16n16c = dnnl_NChw16n16c,
1595 NCdhw16n16c = dnnl_NCdhw16n16c,
1596 NChw32n32c = dnnl_NChw32n32c,
1597 IOhw16i16o = dnnl_IOhw16i16o,
1598 Ohwi32o = dnnl_Ohwi32o,
1599 IOdhw16i16o = dnnl_IOdhw16i16o,
1600 gIOhw16i16o = dnnl_gIOhw16i16o,
1601 gOhwi32o = dnnl_gOhwi32o,
1602 Goidhw16g = dnnl_Goidhw16g,
1603 IOw16o16i = dnnl_IOw16o16i,
1604 OIw16i16o = dnnl_OIw16i16o,
1605 IOw16i16o = dnnl_IOw16i16o,
1606 gIOw16i16o = dnnl_gIOw16i16o,
1607 OIw16o16i = dnnl_OIw16o16i,
1608 Oiw16o = dnnl_Oiw16o,
1609 OIw4i16o4i = dnnl_OIw4i16o4i,
1610 OIw2i8o4i = dnnl_OIw2i8o4i,
1611 OIw4i4o = dnnl_OIw4i4o,
1612 OIw4o4i = dnnl_OIw4o4i,
1613 Oiw4o = dnnl_Oiw4o,
1614 OIw8i16o2i = dnnl_OIw8i16o2i,
1615 OIw8i8o = dnnl_OIw8i8o,
1616 OIw8o16i2o = dnnl_OIw8o16i2o,
1617 OIw8o8i = dnnl_OIw8o8i,
1618 Owi16o = dnnl_Owi16o,
1619 OwI16o2i = dnnl_OwI16o2i,
1620 Owi4o = dnnl_Owi4o,
1621 Owi8o = dnnl_Owi8o,
1622 IOhw16o16i = dnnl_IOhw16o16i,
1623 Ohwi16o = dnnl_Ohwi16o,
1624 OhwI16o2i = dnnl_OhwI16o2i,
1625 Ohwi4o = dnnl_Ohwi4o,
1626 Ohwi8o = dnnl_Ohwi8o,
1627 OIhw16i16o = dnnl_OIhw16i16o,
1628 OIhw16o16i = dnnl_OIhw16o16i,
1629 Oihw16o = dnnl_Oihw16o,
1630 OIhw4i16o4i = dnnl_OIhw4i16o4i,
1631 OIhw4i4o = dnnl_OIhw4i4o,
1632 OIhw4o4i = dnnl_OIhw4o4i,
1633 Oihw4o = dnnl_Oihw4o,
1634 OIhw8i16o2i = dnnl_OIhw8i16o2i,
1635 OIhw8i8o = dnnl_OIhw8i8o,
1636 OIhw8o16i2o = dnnl_OIhw8o16i2o,
1637 OIhw8o8i = dnnl_OIhw8o8i,
1638 OIhw2i8o4i = dnnl_OIhw2i8o4i,
1639 IOdhw16o16i = dnnl_IOdhw16o16i,
1640 Odhwi16o = dnnl_Odhwi16o,
1641 OdhwI16o2i = dnnl_OdhwI16o2i,
1642 Odhwi4o = dnnl_Odhwi4o,
1643 Odhwi8o = dnnl_Odhwi8o,
1644 OIdhw16i16o = dnnl_OIdhw16i16o,
1645 OIdhw16o16i = dnnl_OIdhw16o16i,
1646 Oidhw16o = dnnl_Oidhw16o,
1647 OIdhw4i4o = dnnl_OIdhw4i4o,
1648 OIdhw4o4i = dnnl_OIdhw4o4i,
1649 Oidhw4o = dnnl_Oidhw4o,
1650 OIdhw8i16o2i = dnnl_OIdhw8i16o2i,
1651 OIdhw4i16o4i = dnnl_OIdhw4i16o4i,
1652 OIdhw2i8o4i = dnnl_OIdhw2i8o4i,
1653 OIdhw8i8o = dnnl_OIdhw8i8o,
1654 OIdhw8o8i = dnnl_OIdhw8o8i,
1655 gIOw16o16i = dnnl_gIOw16o16i,
1656 gOIw16i16o = dnnl_gOIw16i16o,
1657 gOIw16o16i = dnnl_gOIw16o16i,
1658 gOiw16o = dnnl_gOiw16o,
1659 gOIw4i16o4i = dnnl_gOIw4i16o4i,
1660 gOIw2i8o4i = dnnl_gOIw2i8o4i,
1661 gOIw4i4o = dnnl_gOIw4i4o,
1662 gOIw4o4i = dnnl_gOIw4o4i,
1663 gOiw4o = dnnl_gOiw4o,
1664 gOIw8i16o2i = dnnl_gOIw8i16o2i,
1665 gOIw8i8o = dnnl_gOIw8i8o,
1666 gOIw8o16i2o = dnnl_gOIw8o16i2o,
1667 gOIw8o8i = dnnl_gOIw8o8i,
1668 gOwi16o = dnnl_gOwi16o,
1669 gOwI16o2i = dnnl_gOwI16o2i,
1670 gOwi4o = dnnl_gOwi4o,
1671 gOwi8o = dnnl_gOwi8o,
1672 Goiw8g = dnnl_Goiw8g,
1673 Goiw16g = dnnl_Goiw16g,
1674 gIOhw16o16i = dnnl_gIOhw16o16i,
1675 gOhwi16o = dnnl_gOhwi16o,
1676 gOhwI16o2i = dnnl_gOhwI16o2i,
1677 gOhwi4o = dnnl_gOhwi4o,
1678 gOhwi8o = dnnl_gOhwi8o,
1679 Goihw16g = dnnl_Goihw16g,
1680 gOIhw16i16o = dnnl_gOIhw16i16o,
1681 gOIhw16o16i = dnnl_gOIhw16o16i,
1682 gOihw16o = dnnl_gOihw16o,
1683 gOIhw4i16o4i = dnnl_gOIhw4i16o4i,
1684 gOIhw2i8o4i = dnnl_gOIhw2i8o4i,
1685 gOIhw4i4o = dnnl_gOIhw4i4o,
1686 gOIhw4o4i = dnnl_gOIhw4o4i,
1687 gOihw4o = dnnl_gOihw4o,
1688 Goihw8g = dnnl_Goihw8g,
1689 gOIhw8i16o2i = dnnl_gOIhw8i16o2i,
1690 gOIhw8i8o = dnnl_gOIhw8i8o,
1691 gOIhw8o16i2o = dnnl_gOIhw8o16i2o,
1692 OIhw4o8i8o4i = dnnl_OIhw4o8i8o4i,
1693 OIhw2o8i8o2i = dnnl_OIhw2o8i8o2i,
1694 gOIhw4o8i8o4i = dnnl_gOIhw4o8i8o4i,
1695 gOIhw2o8i8o2i = dnnl_gOIhw2o8i8o2i,
1696 gOIhw8o8i = dnnl_gOIhw8o8i,
1697 gIOdhw16i16o = dnnl_gIOdhw16i16o,
1698 gIOdhw16o16i = dnnl_gIOdhw16o16i,
1699 gOdhwi16o = dnnl_gOdhwi16o,
1700 gOdhwI16o2i = dnnl_gOdhwI16o2i,
1701 gOdhwi4o = dnnl_gOdhwi4o,
1702 gOdhwi8o = dnnl_gOdhwi8o,
1703 gOIdhw16i16o = dnnl_gOIdhw16i16o,
1704 gOIdhw16o16i = dnnl_gOIdhw16o16i,
1705 gOidhw16o = dnnl_gOidhw16o,
1706 gOIdhw4i4o = dnnl_gOIdhw4i4o,
1707 gOIdhw4o4i = dnnl_gOIdhw4o4i,
1708 gOidhw4o = dnnl_gOidhw4o,
1709 gOIdhw8i16o2i = dnnl_gOIdhw8i16o2i,
1710 gOIdhw4i16o4i = dnnl_gOIdhw4i16o4i,
1711 gOIdhw2i8o4i = dnnl_gOIdhw2i8o4i,
1712 gOIdhw8i8o = dnnl_gOIdhw8i8o,
1713 gOIdhw8o8i = dnnl_gOIdhw8o8i,
1714 gOIw2i4o2i = dnnl_gOIw2i4o2i,
1715 gOIhw2i4o2i = dnnl_gOIhw2i4o2i,
1716 gOIdhw2i4o2i = dnnl_gOIdhw2i4o2i,
1717 gOIw2o4i2o = dnnl_gOIw2o4i2o,
1718 gOIhw2o4i2o = dnnl_gOIhw2o4i2o,
1719 gOIdhw2o4i2o = dnnl_gOIdhw2o4i2o,
1720 gOIw4i8o2i = dnnl_gOIw4i8o2i,
1721 gOIhw4i8o2i = dnnl_gOIhw4i8o2i,
1722 gOIdhw4i8o2i = dnnl_gOIdhw4i8o2i,
1723 gOIw4o8i2o = dnnl_gOIw4o8i2o,
1724 gOIhw4o8i2o = dnnl_gOIhw4o8i2o,
1725 gOIdhw4o8i2o = dnnl_gOIdhw4o8i2o,
1726 };
1727
1729 struct desc {
1730 friend struct memory;
1733
1736 desc() : data() {}
1737
1754 format_tag format_tag, bool allow_empty = false)
1755 : data() {
1758 (int)dims.size(), dims.data(), convert_to_c(data_type),
1759 convert_to_c(format_tag));
1760 if (!allow_empty)
1762 "could not construct a memory descriptor using a "
1763 "format tag");
1764 }
1765
1782 const memory::dims &strides, bool allow_empty = false)
1783 : data() {
1785 if (!strides.empty()) validate_dims(strides, (int)dims.size());
1787 (int)dims.size(), dims.data(), convert_to_c(data_type),
1788 strides.empty() ? nullptr : &strides[0]);
1789 if (!allow_empty)
1791 "could not construct a memory descriptor using "
1792 "strides");
1793 }
1794
1799
1802 //
1812 const memory::dims &offsets, bool allow_empty = false) const {
1813 validate_dims(dims, data.ndims);
1814 validate_dims(offsets, data.ndims);
1817 &sub_md, &data, dims.data(), offsets.data());
1818 if (!allow_empty)
1819 error::wrap_c_api(status, "could not construct a sub-memory");
1820 return desc(sub_md);
1821 }
1822
1867 desc reshape(const memory::dims &dims, bool allow_empty = false) const {
1868 if (data.ndims) validate_dims(dims, 1);
1871 &out_md, &data, (int)dims.size(), dims.data());
1872 if (!allow_empty)
1874 status, "could not reshape a memory descriptor");
1875 return desc(out_md);
1876 }
1877
1914 desc permute_axes(const std::vector<int> &permutation,
1915 bool allow_empty = false) const {
1916 validate_dims(permutation, data.ndims);
1919 &out_md, &data, permutation.data());
1920 if (!allow_empty)
1922 "could not permute axes of a memory descriptor");
1923 return desc(out_md);
1924 }
1925
1931 return memory::dims(data.dims, data.dims + data.ndims);
1932 }
1933
1937 return static_cast<memory::data_type>(data.data_type);
1938 }
1939
1944 size_t get_size() const { return dnnl_memory_desc_get_size(&data); }
1945
1949 bool is_zero() const { return data.ndims == 0; }
1950
1955 bool operator==(const desc &other) const {
1956 return dnnl_memory_desc_equal(&data, &other.data) != 0;
1957 }
1958
1963 bool operator!=(const desc &other) const { return !operator==(other); }
1964 };
1965
1966 // Default constructor.
1967 //
1968 // Constructs an empty memory object, which can be used to indicate absence
1969 // of a parameter.
1970 memory() = default;
1971
1992 memory(const desc &md, const engine &engine, void *handle) {
1993 dnnl_memory_t result;
1995 dnnl_memory_create(&result, &md.data, engine.get(), handle),
1996 "could not create a memory object");
1997 reset(result);
1998 }
1999
2006 memory(const desc &md, const engine &engine)
2007 : memory(md, engine, DNNL_MEMORY_ALLOCATE) {}
2008
2010 desc get_desc() const {
2011 const dnnl_memory_desc_t *cdesc;
2013 "could not get a memory descriptor from a memory object");
2014 return desc(*cdesc);
2015 }
2016
2019 dnnl_engine_t c_engine;
2021 "could not get an engine from a memory object");
2022 return engine(c_engine, true);
2023 }
2024
2028 void *get_data_handle() const {
2029 void *handle;
2031 "could not get a native handle from a memory object");
2032 return handle;
2033 }
2034
2061 void set_data_handle(void *handle, const stream &stream) const {
2064 "could not set native handle of a memory object");
2065 }
2066
2075 void set_data_handle(void *handle) const {
2078 "could not set native handle of a memory object");
2079 }
2080
2101 template <typename T = void>
2102 T *map_data() const {
2103 void *mapped_ptr;
2105 "could not map memory object data");
2106 return static_cast<T *>(mapped_ptr);
2107 }
2108
2118 void unmap_data(void *mapped_ptr) const {
2120 "could not unmap memory object data");
2121 }
2122
2123#if DNNL_GPU_RUNTIME == DNNL_RUNTIME_OCL
2125 cl_mem get_ocl_mem_object() const {
2126 cl_mem mem_object;
2128 "could not get OpenCL buffer object from a memory object");
2129 return mem_object;
2130 }
2131
2139 void set_ocl_mem_object(cl_mem mem_object) {
2141 "could not set OpenCL buffer object from a memory object");
2142 }
2143#endif
2144
2145 static dnnl_data_type_t convert_to_c(data_type data_type) {
2146 return static_cast<dnnl_data_type_t>(data_type);
2147 }
2148 static dnnl_format_tag_t convert_to_c(format_tag format) {
2149 return static_cast<dnnl_format_tag_t>(format);
2150 }
2151};
2152
2153inline bool operator==(dnnl_data_type_t a, memory::data_type b) {
2154 return a == memory::convert_to_c(b);
2155}
2156inline bool operator!=(dnnl_data_type_t a, memory::data_type b) {
2157 return !(a == b);
2158}
2159inline bool operator==(memory::data_type a, dnnl_data_type_t b) {
2160 return b == a;
2161}
2162inline bool operator!=(memory::data_type a, dnnl_data_type_t b) {
2163 return !(a == b);
2164}
2165
2166inline bool operator==(dnnl_format_tag_t a, memory::format_tag b) {
2167 return a == memory::convert_to_c(b);
2168}
2169inline bool operator!=(dnnl_format_tag_t a, memory::format_tag b) {
2170 return !(a == b);
2171}
2172inline bool operator==(memory::format_tag a, dnnl_format_tag_t b) {
2173 return b == a;
2174}
2175inline bool operator!=(memory::format_tag a, dnnl_format_tag_t b) {
2176 return !(a == b);
2177}
2178
2180
2188
2190template <>
2192 static dnnl_status_t destructor(dnnl_post_ops_t p) {
2193 return dnnl_post_ops_destroy(p);
2194 }
2195};
2197
2205struct post_ops : public handle<dnnl_post_ops_t> {
2207
2210 dnnl_post_ops_t result;
2212 dnnl_post_ops_create(&result), "could not create post-ops");
2213 reset(result);
2214 }
2215
2217 int len() const { return dnnl_post_ops_len(get()); }
2218
2222 primitive::kind kind(int index) const {
2224 "post-ops index is out of range");
2225 return static_cast<primitive::kind>(
2226 dnnl_post_ops_get_kind(get(), index));
2227 }
2228
2251 void append_sum(float scale = 1.) {
2253 "could not append a sum post-op");
2254 }
2255
2260 void get_params_sum(int index, float &scale) const {
2262 "could not get parameters of a sum post-op");
2263 }
2264
2281 float scale, algorithm algorithm, float alpha, float beta) {
2283 convert_to_c(algorithm), alpha, beta),
2284 "could not append an elementwise post-op");
2285 }
2286
2294 void get_params_eltwise(int index, float &scale, algorithm &algorithm,
2295 float &alpha, float &beta) const {
2296 dnnl_alg_kind_t c_alg;
2298 get(), index, &scale, &c_alg, &alpha, &beta),
2299 "could not get parameters of an elementwise post-op");
2300 algorithm = static_cast<dnnl::algorithm>(c_alg);
2301 }
2302
2331 void append_dw_k3s1p1(memory::data_type weights_data_type,
2332 memory::data_type bias_data_type, memory::data_type dst_data_type,
2333 int mask, const std::vector<float> &scales) {
2334
2336 memory::convert_to_c(weights_data_type),
2337 memory::convert_to_c(bias_data_type),
2338 memory::convert_to_c(dst_data_type),
2339 scales.size(), mask, &scales[0]),
2340 "could not append depthwise post-op");
2341 }
2342
2357 void get_params_dw_k3s1p1(int index, memory::data_type &weights_data_type,
2358 memory::data_type &bias_data_type, memory::data_type &dst_data_type,
2359 int &mask, std::vector<float> &scales) const {
2360
2361 dnnl_data_type_t c_weights_data_type;
2362 dnnl_data_type_t c_bias_data_type;
2363 dnnl_data_type_t c_dst_data_type;
2364 dnnl_dim_t count;
2365 int c_mask;
2366 const float *c_scales;
2368 &c_weights_data_type, &c_bias_data_type,
2369 &c_dst_data_type, &count, &c_mask, &c_scales),
2370 "could not get parameters of depthwise post-op");
2371
2372 weights_data_type = static_cast<memory::data_type>(c_weights_data_type);
2373 bias_data_type = static_cast<memory::data_type>(c_bias_data_type);
2374 dst_data_type = static_cast<memory::data_type>(c_dst_data_type);
2375 scales.resize(count);
2376
2377 mask = c_mask;
2378 for (dnnl_dim_t c = 0; c < count; ++c)
2379 scales[c] = c_scales[c];
2380 return;
2381 }
2382
2416 void append_dw_k3s2p1(memory::data_type weights_data_type,
2417 memory::data_type bias_data_type, memory::data_type dst_data_type,
2418 int mask, const std::vector<float> &scales) {
2419
2421 memory::convert_to_c(weights_data_type),
2422 memory::convert_to_c(bias_data_type),
2423 memory::convert_to_c(dst_data_type),
2424 scales.size(), mask, &scales[0]),
2425 "could not append depthwise post-op");
2426 }
2427
2442 void get_params_dw_k3s2p1(int index, memory::data_type &weights_data_type,
2443 memory::data_type &bias_data_type, memory::data_type &dst_data_type,
2444 int &mask, std::vector<float> &scales) const {
2445
2446 dnnl_data_type_t c_weights_data_type;
2447 dnnl_data_type_t c_bias_data_type;
2448 dnnl_data_type_t c_dst_data_type;
2449 dnnl_dim_t count;
2450 int c_mask;
2451 const float *c_scales;
2453 &c_weights_data_type, &c_bias_data_type,
2454 &c_dst_data_type, &count, &c_mask, &c_scales),
2455 "could not get parameters of depthwise post-op");
2456
2457 weights_data_type = static_cast<memory::data_type>(c_weights_data_type);
2458 bias_data_type = static_cast<memory::data_type>(c_bias_data_type);
2459 dst_data_type = static_cast<memory::data_type>(c_dst_data_type);
2460 scales.resize(count);
2461
2462 mask = c_mask;
2463 for (dnnl_dim_t c = 0; c < count; ++c)
2464 scales[c] = c_scales[c];
2465 return;
2466 }
2467};
2468
2470template <>
2471struct handle_traits<dnnl_primitive_attr_t> {
2472 static dnnl_status_t destructor(dnnl_primitive_attr_t p) {
2474 }
2475};
2477
2481struct primitive_attr : public handle<dnnl_primitive_attr_t> {
2483
2486 dnnl_primitive_attr_t result;
2488 "could not create primitive attribute");
2489 reset(result);
2490 }
2491
2499
2505 "could not get scratchpad mode primitive attribute");
2506 return scratchpad_mode(result);
2507 }
2508
2514 get(), dnnl::convert_to_c(mode)),
2515 "could not set scratchpad mode primitive attribute");
2516 }
2517
2527 void get_output_scales(int &mask, std::vector<float> &scales) const {
2528 dnnl_dim_t count;
2529 int c_mask;
2530 const float *c_scales;
2532 get(), &count, &c_mask, &c_scales),
2533 "could not get output scales primitive attribute");
2534 scales.resize(count);
2535
2536 mask = c_mask;
2537 for (dnnl_dim_t c = 0; c < count; ++c)
2538 scales[c] = c_scales[c];
2539 }
2540
2583 void set_output_scales(int mask, const std::vector<float> &scales) {
2586 get(), (dnnl_dim_t)scales.size(), mask, scales.data()),
2587 "could not set output scales primitive attribute");
2588 }
2589
2601 void get_scales(int arg, int &mask, std::vector<float> &scales) const {
2602 dnnl_dim_t count;
2603 int c_mask;
2604 const float *c_scales;
2606 get(), arg, &count, &c_mask, &c_scales),
2607 "could not get scales primitive attributes");
2608 scales.resize(count);
2609
2610 mask = c_mask;
2611 for (dnnl_dim_t c = 0; c < count; ++c)
2612 scales[c] = c_scales[c];
2613 }
2614
2631 void set_scales(int arg, int mask, const std::vector<float> &scales) {
2634 (dnnl_dim_t)scales.size(), mask, scales.data()),
2635 "could not set scales primitive attribute");
2636 }
2637
2649 int arg, int &mask, std::vector<int32_t> &zero_points) const {
2650 dnnl_dim_t count;
2651 int c_mask;
2652 const int32_t *c_zero_points;
2654 get(), arg, &count, &c_mask, &c_zero_points),
2655 "could not get zero points primitive attribute");
2656 zero_points.resize(count);
2657
2658 mask = c_mask;
2659 for (dnnl_dim_t c = 0; c < count; ++c)
2660 zero_points[c] = c_zero_points[c];
2661 }
2662
2684 int arg, int mask, const std::vector<int32_t> &zero_points) {
2686 (dnnl_dim_t)zero_points.size(), mask,
2687 zero_points.data()),
2688 "could not set zero points primitive attribute");
2689 }
2690
2694 const post_ops get_post_ops() const {
2695 post_ops result;
2696 const_dnnl_post_ops_t c_result;
2698 "could not get post-ops primitive attribute");
2699 result.reset(const_cast<dnnl_post_ops_t>(c_result), true);
2700 return result;
2701 }
2702
2711 void set_post_ops(const post_ops ops) {
2713 "could not set post-ops primitive attribute");
2714 }
2715
// Setter for the RNN data quantization parameters (scale, shift); on failure the thrown error now says "set", matching the other setters.
2749 void set_rnn_data_qparams(float scale, float shift) {
2752 "could not set RNN data quantization parameters primitive "
2753 "attribute");
2754 }
2755
// Setter for the RNN weights quantization scales (mask + per-element scales); on failure the thrown error now says "set", matching the other setters.
2782 void set_rnn_weights_qparams(int mask, const std::vector<float> &scales) {
2784 (int)scales.size(), mask, scales.data()),
2785 "could not set RNN weights quantization parameters primitive "
2786 "attribute");
2787 }
2788};
2789
2791
2794
2796struct primitive_desc_base : public handle<dnnl_primitive_desc_t> {
2798
2801
2804 engine get_engine() const { return engine::query(*this); }
2805
2808 const char *impl_info_str() const {
2809 const char *res;
2811 get(), dnnl_query_impl_info_str, 0, &res),
2812 "could not retrieve implementation info string from a "
2813 "primitive descriptor");
2814 return res;
2815 }
2816
2821 memory::dim res;
2823 get(), dnnl::convert_to_c(what), 0, &res);
2824 return status == dnnl_success ? res : 0;
2825 }
2826
2841 memory::desc query_md(query what, int idx = 0) const {
2842 std::vector<query> valid_q {query::src_md, query::diff_src_md,
2846 if (!std::any_of(valid_q.cbegin(), valid_q.cend(),
2847 [=](query q) { return what == q; }))
2848 DNNL_THROW_ERROR(dnnl_invalid_arguments,
2849 "memory descriptor query is invalid");
2850
2852 get(), dnnl::convert_to_c(what), idx);
2853 return cdesc ? memory::desc(*cdesc) : memory::desc();
2854 }
2855
2861 memory::desc src_desc(int idx) const {
2862 return query_md(query::src_md, idx);
2863 }
2864
2870 memory::desc dst_desc(int idx) const {
2871 return query_md(query::dst_md, idx);
2872 }
2873
2880 return query_md(query::weights_md, idx);
2881 }
2882
2889 return query_md(query::diff_src_md, idx);
2890 }
2891
2898 return query_md(query::diff_dst_md, idx);
2899 }
2900
2907 return query_md(query::diff_weights_md, idx);
2908 }
2909
2910 // Separate versions without the index argument for documentation
2911 // purposes.
2912
2917 memory::desc src_desc() const { return src_desc(0); }
2918
2923 memory::desc dst_desc() const { return dst_desc(0); }
2924
2930
2936
2942
2948
2954 return query_md(query::workspace_md, 0);
2955 }
2956
2965
2969 dnnl_engine_t c_engine;
2972 0, &c_engine),
2973 "could not retrieve scratchpad engine from a primitive "
2974 "descriptor");
2975 return engine(c_engine, true);
2976 }
2977
2981 const_dnnl_primitive_attr_t const_c_attr;
2983 "could not get attributes from a primitive descriptor");
2984 dnnl_primitive_attr_t c_attr;
2985 error::wrap_c_api(dnnl_primitive_attr_clone(&c_attr, const_c_attr),
2986 "could not clone primitive attributes");
2987 return primitive_attr(c_attr);
2988 }
2989
2995 dnnl_query_primitive_kind, 0, (void *)&kind),
2996 "could not get primitive kind from a primitive descriptor");
2997 return static_cast<dnnl::primitive::kind>(kind);
2998 }
2999
3000protected:
3005 dnnl_primitive_desc_t new_pd;
3007 "could not clone a primitive descriptor");
3008 reset(new_pd);
3009 }
3010
3027
3042
3057 dnnl::primitive::kind prim_kind, dnnl::prop_kind prop_kind1,
3058 dnnl::prop_kind prop_kind2) {
3059 // It is OK to pass an empty primitive descriptor
3060 if (pd == nullptr) return;
3061
3062 dnnl_status_t rc;
3063
3064 dnnl_primitive_kind_t c_prim_kind = convert_to_c(prim_kind);
3065 dnnl_prop_kind_t c_prop_kind1 = convert_to_c(prop_kind1);
3066 dnnl_prop_kind_t c_prop_kind2 = convert_to_c(prop_kind2);
3067
3068 // Check that primitive kind matches
3069 dnnl_primitive_kind_t pd_kind;
3071 pd, dnnl_query_primitive_kind, 0, (void *)&pd_kind);
3073 rc, "could not get primitive kind from a primitive descriptor");
3074 if (pd_kind != c_prim_kind)
3075 DNNL_THROW_ERROR(dnnl_invalid_arguments,
3076 "primitive descriptor operation kind mismatch");
3077
3078 // Check that propagation kind matches
3079 dnnl_prop_kind_t pd_prop_kind;
3081 pd, dnnl_query_prop_kind, 0, (void *)&pd_prop_kind);
3082
3083 // Something went wrong
3084 if (rc != dnnl_success && rc != dnnl_unimplemented)
3085 DNNL_THROW_ERROR(dnnl_invalid_arguments,
3086 "could not get propagation kind from the primitive "
3087 "descriptor");
3088
3089 // Everything is fine
3090 if ((rc == dnnl_unimplemented && c_prop_kind1 == dnnl_prop_kind_undef)
3091 || (rc == dnnl_success
3092 && (pd_prop_kind == c_prop_kind1
3093 || pd_prop_kind == c_prop_kind2))) {
3094 reset_with_clone(pd);
3095 return;
3096 }
3097
3098 // We could get the propagation kind but there is a mismatch
3099 DNNL_THROW_ERROR(dnnl_invalid_arguments,
3100 "primitive descriptor propagation kind mismatch");
3101 }
3102
3103 using base = primitive_desc_base;
3104};
3105
3107
3116
3118struct reorder : public primitive {
3122
3124 primitive_desc() = default;
3125
3141 primitive_desc(const engine &src_engine, const memory::desc &src_md,
3142 const engine &dst_engine, const memory::desc &dst_md,
3143 const primitive_attr &attr = primitive_attr()) {
3144 dnnl_primitive_desc_t result;
3147 src_engine.get(), &dst_md.data, dst_engine.get(),
3148 attr.get()),
3149 "could not create a primitive descriptor for a reorder "
3150 "primitive");
3151 reset(result);
3152 }
3153
3161 primitive_desc(const memory &src, const memory &dst,
3162 const primitive_attr &attr = primitive_attr()) {
3163 dnnl_primitive_desc_t result;
3164 auto src_md = src.get_desc();
3165 auto dst_md = dst.get_desc();
3168 src.get_engine().get(), &dst_md.data,
3169 dst.get_engine().get(), attr.get()),
3170 "could not create a primitive descriptor for a reorder "
3171 "primitive");
3172 reset(result);
3173 }
3174
3181
3187
3193
3195 memory::desc src_desc() const { return base::src_desc(0); }
3196
3198 memory::desc dst_desc() const { return base::dst_desc(0); }
3199 };
3200
3202 reorder() = default;
3203
3206 reorder(const primitive_desc &pd) : primitive(pd.get()) {}
3207
3215 reorder(const memory &src, const memory &dst,
3216 const primitive_attr &attr = primitive_attr())
3217 : primitive(primitive_desc(src, dst, attr).get()) {}
3218
3219 using primitive::execute;
3220
3227 void execute(const stream &stream, memory &src, memory &dst) const {
3228 primitive::execute(stream, {{DNNL_ARG_FROM, src}, {DNNL_ARG_TO, dst}});
3229 }
3230};
3231
3233
3241
3243inline std::vector<dnnl_memory_desc_t> convert_to_c(
3244 const std::vector<memory::desc> &mems) {
3245 std::vector<dnnl_memory_desc_t> c_mems;
3246 c_mems.reserve(mems.size());
3247 for (const auto &s : mems)
3248 c_mems.push_back(s.data);
3249 return c_mems;
3250}
3252
3254struct concat : public primitive {
3258
3260 primitive_desc() = default;
3261
3281 primitive_desc(const memory::desc &dst, int concat_dimension,
3282 const std::vector<memory::desc> &srcs, const engine &engine,
3283 const primitive_attr &attr = primitive_attr()) {
3284 auto c_srcs = convert_to_c(srcs);
3285
3286 dnnl_primitive_desc_t result;
3289 (int)c_srcs.size(), concat_dimension, c_srcs.data(),
3290 attr.get(), engine.get()),
3291 "could not create a primitive descriptor for a concat "
3292 "primitive");
3293 reset(result);
3294 }
3295
3308 primitive_desc(int concat_dimension,
3309 const std::vector<memory::desc> &srcs, const engine &engine,
3310 const primitive_attr &attr = primitive_attr()) {
3311 auto c_api_srcs = convert_to_c(srcs);
3312
3313 dnnl_primitive_desc_t result;
3315 dnnl_concat_primitive_desc_create(&result, nullptr,
3316 (int)c_api_srcs.size(), concat_dimension,
3317 c_api_srcs.data(), attr.get(), engine.get()),
3318 "could not create a primitive descriptor for a concat "
3319 "primitive");
3320 reset(result);
3321 }
3322
3329
3331 memory::desc src_desc(int idx = 0) const { return base::src_desc(idx); }
3332
3334 memory::desc dst_desc() const { return base::dst_desc(0); }
3335 };
3336
3338 concat() = default;
3339
3342 concat(const primitive_desc &pd) : primitive(pd.get()) {}
3343};
3344
3346
3354
3356struct sum : public primitive {
3360
3362 primitive_desc() = default;
3363
3382 const std::vector<float> &scales,
3383 const std::vector<memory::desc> &srcs, const engine &engine,
3384 const primitive_attr &attr = primitive_attr()) {
3385 validate_container_size(scales,
3386 "counts of scales and sources are not equal",
3387 (int)srcs.size(), (int)srcs.size());
3388
3389 auto c_api_srcs = convert_to_c(srcs);
3390
3391 dnnl_primitive_desc_t result;
3394 (int)c_api_srcs.size(), scales.data(),
3395 c_api_srcs.data(), attr.get(), engine.get()),
3396 "could not create a primitive descriptor for a sum "
3397 "primitive");
3398 reset(result);
3399 }
3400
3411 primitive_desc(const std::vector<float> &scales,
3412 const std::vector<memory::desc> &srcs, const engine &engine,
3413 const primitive_attr &attr = primitive_attr()) {
3414 validate_container_size(scales,
3415 "counts of scales and sources are not equal",
3416 (int)srcs.size(), (int)srcs.size());
3417
3418 auto c_api_srcs = convert_to_c(srcs);
3419 dnnl_primitive_desc_t result;
3421 dnnl_sum_primitive_desc_create(&result, nullptr,
3422 (int)c_api_srcs.size(), scales.data(),
3423 c_api_srcs.data(), attr.get(), engine.get()),
3424 "could not create a primitive descriptor for a sum "
3425 "primitive");
3426 reset(result);
3427 }
3428
3435
3437 memory::desc src_desc(int idx = 0) const { return base::src_desc(idx); }
3438
3440 memory::desc dst_desc() const { return base::dst_desc(0); }
3441 };
3442
3444 sum() = default;
3445
3448 sum(const primitive_desc &pd) : primitive(pd.get()) {}
3449};
3450
3452
3455
3458struct primitive_desc : public primitive_desc_base {
3460
3461 primitive_desc() = default;
3462
3486 const engine &engine, const_dnnl_primitive_desc_t hint_fwd_pd,
3487 bool allow_empty = false)
3488 : allow_empty_(allow_empty) {
3489 dnnl_primitive_desc_iterator_t iterator = nullptr;
3491 desc, attr ? attr->get() : nullptr, engine.get(), hint_fwd_pd);
3492 if (!allow_empty)
3494 status, "could not create a primitive descriptor iterator");
3495 pd_iterator.reset(iterator);
3496 fetch_impl();
3497 }
3498
3503 bool next_impl() {
3505 = dnnl_primitive_desc_iterator_next(pd_iterator.get());
3506 if (status == dnnl_iterator_ends) return false;
3508 status, "could not advance a primitive descriptor iterator");
3509 fetch_impl();
3510 return true;
3511 }
3512
3513private:
3514 bool allow_empty_ = false;
3516 void fetch_impl() {
3518 pd_iterator.get(allow_empty_));
3519 error::wrap_c_api(pd != nullptr || allow_empty_ ? dnnl_success
3521 "could not fetch a primitive descriptor from a primitive "
3522 "descriptor iterator");
3523 reset(pd);
3524 }
3525};
3526
3528
3538
3542 struct desc {
3544
3578 const memory::desc &src_desc, const memory::desc &weights_desc,
3579 const memory::desc &bias_desc, const memory::desc &dst_desc,
3580 const memory::dims &strides, const memory::dims &padding_l,
3581 const memory::dims &padding_r) {
3582 memory::validate_dims(strides, src_desc.data.ndims - 2);
3583 memory::validate_dims(padding_l, src_desc.data.ndims - 2);
3584 memory::validate_dims(padding_r, src_desc.data.ndims - 2);
3588 convert_to_c(algorithm), &src_desc.data,
3589 &weights_desc.data, &bias_desc.data, &dst_desc.data,
3590 &strides[0], &padding_l[0], &padding_r[0]),
3591 "could not create a descriptor for a convolution forward "
3592 "propagation primitive");
3593 }
3594
3625 const memory::desc &src_desc, const memory::desc &weights_desc,
3626 const memory::desc &dst_desc, const memory::dims &strides,
3627 const memory::dims &padding_l, const memory::dims &padding_r) {
3628 memory::validate_dims(strides, src_desc.data.ndims - 2);
3629 memory::validate_dims(padding_l, src_desc.data.ndims - 2);
3630 memory::validate_dims(padding_r, src_desc.data.ndims - 2);
3634 convert_to_c(algorithm), &src_desc.data,
3635 &weights_desc.data, nullptr, &dst_desc.data,
3636 &strides[0], &padding_l[0], &padding_r[0]),
3637 "could not create a descriptor for a convolution forward "
3638 "propagation primitive");
3639 }
3640
3676 const memory::desc &src_desc, const memory::desc &weights_desc,
3677 const memory::desc &bias_desc, const memory::desc &dst_desc,
3678 const memory::dims &strides, const memory::dims &dilates,
3679 const memory::dims &padding_l, const memory::dims &padding_r) {
3680 memory::validate_dims(strides, src_desc.data.ndims - 2);
3681 memory::validate_dims(dilates, src_desc.data.ndims - 2);
3682 memory::validate_dims(padding_l, src_desc.data.ndims - 2);
3683 memory::validate_dims(padding_r, src_desc.data.ndims - 2);
3686 convert_to_c(algorithm), &src_desc.data,
3687 &weights_desc.data, &bias_desc.data,
3688 &dst_desc.data, &strides[0], &dilates[0],
3689 &padding_l[0], &padding_r[0]),
3690 "could not create a descriptor for a dilated convolution "
3691 "forward propagation primitive");
3692 }
3693
3726 const memory::desc &src_desc, const memory::desc &weights_desc,
3727 const memory::desc &dst_desc, const memory::dims &strides,
3728 const memory::dims &dilates, const memory::dims &padding_l,
3729 const memory::dims &padding_r) {
3730 memory::validate_dims(strides, src_desc.data.ndims - 2);
3731 memory::validate_dims(dilates, src_desc.data.ndims - 2);
3732 memory::validate_dims(padding_l, src_desc.data.ndims - 2);
3733 memory::validate_dims(padding_r, src_desc.data.ndims - 2);
3736 convert_to_c(algorithm), &src_desc.data,
3737 &weights_desc.data, nullptr,
3738 &dst_desc.data, &strides[0], &dilates[0],
3739 &padding_l[0], &padding_r[0]),
3740 "could not create a descriptor for a dilated convolution "
3741 "forward propagation primitive");
3742 }
3743 };
3744
3748 primitive_desc() = default;
3749
3761 bool allow_empty = false)
3763 &desc.data, nullptr, engine, nullptr, allow_empty) {}
3764
3777 const engine &engine, bool allow_empty = false)
3779 &desc.data, &attr, engine, nullptr, allow_empty) {}
3780
3791
3793 memory::desc src_desc() const { return base::src_desc(0); }
3794
3797
3799 memory::desc dst_desc() const { return base::dst_desc(0); }
3800
3806 };
3807
3810
3815};
3816
3819
3821 struct desc {
3823
3850 desc(algorithm algorithm, const memory::desc &diff_src_desc,
3851 const memory::desc &weights_desc,
3852 const memory::desc &diff_dst_desc, const memory::dims &strides,
3853 const memory::dims &padding_l, const memory::dims &padding_r) {
3854 memory::validate_dims(strides, diff_src_desc.data.ndims - 2);
3855 memory::validate_dims(padding_l, diff_src_desc.data.ndims - 2);
3856 memory::validate_dims(padding_r, diff_src_desc.data.ndims - 2);
3859 convert_to_c(algorithm), &diff_src_desc.data,
3860 &weights_desc.data, &diff_dst_desc.data,
3861 &strides[0], &padding_l[0], &padding_r[0]),
3862 "could not create a descriptor for a convolution backward "
3863 "propagation primitive");
3864 }
3865
3894 desc(algorithm algorithm, const memory::desc &diff_src_desc,
3895 const memory::desc &weights_desc,
3896 const memory::desc &diff_dst_desc, const memory::dims &strides,
3897 const memory::dims &dilates, const memory::dims &padding_l,
3898 const memory::dims &padding_r) {
3899 memory::validate_dims(strides, diff_src_desc.data.ndims - 2);
3900 memory::validate_dims(dilates, diff_src_desc.data.ndims - 2);
3901 memory::validate_dims(padding_l, diff_src_desc.data.ndims - 2);
3902 memory::validate_dims(padding_r, diff_src_desc.data.ndims - 2);
3905 convert_to_c(algorithm), &diff_src_desc.data,
3906 &weights_desc.data, &diff_dst_desc.data,
3907 &strides[0], &dilates[0], &padding_l[0],
3908 &padding_r[0]),
3909 "could not create a descriptor for a dilated convolution "
3910 "backward propagation primitive");
3911 }
3912 };
3913
3917 primitive_desc() = default;
3918
3933 const convolution_forward::primitive_desc &hint_fwd_pd,
3934 bool allow_empty = false)
3935 : dnnl::primitive_desc(&desc.data, nullptr, engine,
3936 hint_fwd_pd.get(), allow_empty) {}
3937
3953 const engine &engine,
3954 const convolution_forward::primitive_desc &hint_fwd_pd,
3955 bool allow_empty = false)
3956 : dnnl::primitive_desc(&desc.data, &attr, engine, hint_fwd_pd.get(),
3957 allow_empty) {}
3958
3968
3971
3974
3977 };
3978
3981
3986};
3987
3991 struct desc {
3993
4024 const memory::desc &diff_weights_desc,
4025 const memory::desc &diff_bias_desc,
4026 const memory::desc &diff_dst_desc, const memory::dims &strides,
4027 const memory::dims &padding_l, const memory::dims &padding_r) {
4028 memory::validate_dims(strides, src_desc.data.ndims - 2);
4029 memory::validate_dims(padding_l, src_desc.data.ndims - 2);
4030 memory::validate_dims(padding_r, src_desc.data.ndims - 2);
4033 convert_to_c(algorithm), &src_desc.data,
4034 &diff_weights_desc.data, &diff_bias_desc.data,
4035 &diff_dst_desc.data, &strides[0], &padding_l[0],
4036 &padding_r[0]),
4037 "could not create a descriptor for a convolution weights "
4038 "update primitive");
4039 }
4040
4068 const memory::desc &diff_weights_desc,
4069 const memory::desc &diff_dst_desc, const memory::dims &strides,
4070 const memory::dims &padding_l, const memory::dims &padding_r) {
4071 memory::validate_dims(strides, src_desc.data.ndims - 2);
4072 memory::validate_dims(padding_l, src_desc.data.ndims - 2);
4073 memory::validate_dims(padding_r, src_desc.data.ndims - 2);
4075 convert_to_c(algorithm), &src_desc.data,
4076 &diff_weights_desc.data, nullptr,
4077 &diff_dst_desc.data, &strides[0],
4078 &padding_l[0], &padding_r[0]),
4079 "could not create a descriptor for a convolution weights "
4080 "update primitive");
4081 }
4082
4115 const memory::desc &diff_weights_desc,
4116 const memory::desc &diff_bias_desc,
4117 const memory::desc &diff_dst_desc, const memory::dims &strides,
4118 const memory::dims &dilates, const memory::dims &padding_l,
4119 const memory::dims &padding_r) {
4120 memory::validate_dims(strides, src_desc.data.ndims - 2);
4121 memory::validate_dims(dilates, src_desc.data.ndims - 2);
4122 memory::validate_dims(padding_l, src_desc.data.ndims - 2);
4123 memory::validate_dims(padding_r, src_desc.data.ndims - 2);
4126 convert_to_c(algorithm), &src_desc.data,
4127 &diff_weights_desc.data, &diff_bias_desc.data,
4128 &diff_dst_desc.data, &strides[0], &dilates[0],
4129 &padding_l[0], &padding_r[0]),
4130 "could not create a descriptor for a dilated convolution "
4131 "weights gradient primitive");
4132 }
4133
4163 const memory::desc &diff_weights_desc,
4164 const memory::desc &diff_dst_desc, const memory::dims &strides,
4165 const memory::dims &dilates, const memory::dims &padding_l,
4166 const memory::dims &padding_r) {
4167 memory::validate_dims(strides, src_desc.data.ndims - 2);
4168 memory::validate_dims(dilates, src_desc.data.ndims - 2);
4169 memory::validate_dims(padding_l, src_desc.data.ndims - 2);
4170 memory::validate_dims(padding_r, src_desc.data.ndims - 2);
4173 convert_to_c(algorithm), &src_desc.data,
4174 &diff_weights_desc.data, nullptr,
4175 &diff_dst_desc.data, &strides[0], &dilates[0],
4176 &padding_l[0], &padding_r[0]),
4177 "could not create a descriptor for a dilated convolution "
4178 "weights gradient primitive");
4179 }
4180 };
4181
4185 primitive_desc() = default;
4186
4200 const convolution_forward::primitive_desc &hint_fwd_pd,
4201 bool allow_empty = false)
4202 : dnnl::primitive_desc(&desc.data, nullptr, engine,
4203 hint_fwd_pd.get(), allow_empty) {}
4204
4219 const engine &engine,
4220 const convolution_forward::primitive_desc &hint_fwd_pd,
4221 bool allow_empty = false)
4222 : dnnl::primitive_desc(&desc.data, &attr, engine, hint_fwd_pd.get(),
4223 allow_empty) {}
4224
4234
4236 memory::desc src_desc() const { return base::src_desc(0); }
4237
4242
4245
4251 return base::diff_weights_desc(1);
4252 }
4253 };
4254
4257
4262};
4263
4265//
4273
4277 struct desc {
4279
4312 const memory::desc &src_desc, const memory::desc &weights_desc,
4313 const memory::desc &bias_desc, const memory::desc &dst_desc,
4314 const memory::dims &strides, const memory::dims &padding_l,
4315 const memory::dims &padding_r) {
4316 memory::validate_dims(strides, src_desc.data.ndims - 2);
4317 memory::validate_dims(padding_l, src_desc.data.ndims - 2);
4318 memory::validate_dims(padding_r, src_desc.data.ndims - 2);
4322 convert_to_c(algorithm), &src_desc.data,
4323 &weights_desc.data, &bias_desc.data, &dst_desc.data,
4324 &strides[0], &padding_l[0], &padding_r[0]),
4325 "could not create a descriptor for a deconvolution forward "
4326 "propagation primitive");
4327 }
4328
4358 const memory::desc &src_desc, const memory::desc &weights_desc,
4359 const memory::desc &dst_desc, const memory::dims &strides,
4360 const memory::dims &padding_l, const memory::dims &padding_r) {
4361 memory::validate_dims(strides, src_desc.data.ndims - 2);
4362 memory::validate_dims(padding_l, src_desc.data.ndims - 2);
4363 memory::validate_dims(padding_r, src_desc.data.ndims - 2);
4367 convert_to_c(algorithm), &src_desc.data,
4368 &weights_desc.data, nullptr, &dst_desc.data,
4369 &strides[0], &padding_l[0], &padding_r[0]),
4370 "could not create a descriptor for a deconvolution forward "
4371 "propagation primitive");
4372 }
4373
4408 const memory::desc &src_desc, const memory::desc &weights_desc,
4409 const memory::desc &bias_desc, const memory::desc &dst_desc,
4410 const memory::dims &strides, const memory::dims &dilates,
4411 const memory::dims &padding_l, const memory::dims &padding_r) {
4412 memory::validate_dims(strides, src_desc.data.ndims - 2);
4413 memory::validate_dims(dilates, src_desc.data.ndims - 2);
4414 memory::validate_dims(padding_l, src_desc.data.ndims - 2);
4415 memory::validate_dims(padding_r, src_desc.data.ndims - 2);
4418 convert_to_c(algorithm), &src_desc.data,
4419 &weights_desc.data, &bias_desc.data,
4420 &dst_desc.data, &strides[0], &dilates[0],
4421 &padding_l[0], &padding_r[0]),
4422 "could not create a descriptor for a dilated deconvolution "
4423 "forward propagation primitive");
4424 }
4425
4457 const memory::desc &src_desc, const memory::desc &weights_desc,
4458 const memory::desc &dst_desc, const memory::dims &strides,
4459 const memory::dims &dilates, const memory::dims &padding_l,
4460 const memory::dims &padding_r) {
4461 memory::validate_dims(strides, src_desc.data.ndims - 2);
4462 memory::validate_dims(dilates, src_desc.data.ndims - 2);
4463 memory::validate_dims(padding_l, src_desc.data.ndims - 2);
4464 memory::validate_dims(padding_r, src_desc.data.ndims - 2);
4467 convert_to_c(algorithm), &src_desc.data,
4468 &weights_desc.data, nullptr,
4469 &dst_desc.data, &strides[0], &dilates[0],
4470 &padding_l[0], &padding_r[0]),
4471 "could not create a descriptor for a dilated deconvolution "
4472 "forward propagation primitive");
4473 }
4474 };
4475
4479 primitive_desc() = default;
4480
4492 bool allow_empty = false)
4494 &desc.data, nullptr, engine, nullptr, allow_empty) {}
4495
4508 const engine &engine, bool allow_empty = false)
4510 &desc.data, &attr, engine, nullptr, allow_empty) {}
4511
4522
4524 memory::desc src_desc() const { return base::src_desc(0); }
4525
4528
4530 memory::desc dst_desc() const { return base::dst_desc(0); }
4531
4534 };
4535
4538
4543};
4544
4548 struct desc {
4550
4576 desc(algorithm algorithm, const memory::desc &diff_src_desc,
4577 const memory::desc &weights_desc,
4578 const memory::desc &diff_dst_desc, const memory::dims &strides,
4579 const memory::dims &padding_l, const memory::dims &padding_r) {
4580 memory::validate_dims(strides, diff_src_desc.data.ndims - 2);
4581 memory::validate_dims(padding_l, diff_src_desc.data.ndims - 2);
4582 memory::validate_dims(padding_r, diff_src_desc.data.ndims - 2);
4585 convert_to_c(algorithm), &diff_src_desc.data,
4586 &weights_desc.data, &diff_dst_desc.data,
4587 &strides[0], &padding_l[0], &padding_r[0]),
4588 "could not create a descriptor for a deconvolution "
4589 "backward propagation primitive");
4590 }
4591
4619 desc(algorithm algorithm, const memory::desc &diff_src_desc,
4620 const memory::desc &weights_desc,
4621 const memory::desc &diff_dst_desc, const memory::dims &strides,
4622 const memory::dims &dilates, const memory::dims &padding_l,
4623 const memory::dims &padding_r) {
4624 memory::validate_dims(strides, diff_src_desc.data.ndims - 2);
4625 memory::validate_dims(dilates, diff_src_desc.data.ndims - 2);
4626 memory::validate_dims(padding_l, diff_src_desc.data.ndims - 2);
4627 memory::validate_dims(padding_r, diff_src_desc.data.ndims - 2);
4630 convert_to_c(algorithm), &diff_src_desc.data,
4631 &weights_desc.data, &diff_dst_desc.data,
4632 &strides[0], &dilates[0], &padding_l[0],
4633 &padding_r[0]),
4634 "could not create a descriptor for a dilated deconvolution "
4635 "backward propagation primitive");
4636 }
4637 };
4638
4642 primitive_desc() = default;
4643
4658 const deconvolution_forward::primitive_desc &hint_fwd_pd,
4659 bool allow_empty = false)
4660 : dnnl::primitive_desc(&desc.data, nullptr, engine,
4661 hint_fwd_pd.get(), allow_empty) {}
4662
4678 const engine &engine,
4679 const deconvolution_forward::primitive_desc &hint_fwd_pd,
4680 bool allow_empty = false)
4681 : dnnl::primitive_desc(&desc.data, &attr, engine, hint_fwd_pd.get(),
4682 allow_empty) {}
4683
4693
4696
4699
4702 };
4703
4706
4711};
4712
4716 struct desc {
4718
4748 const memory::desc &diff_weights_desc,
4749 const memory::desc &diff_bias_desc,
4750 const memory::desc &diff_dst_desc, const memory::dims &strides,
4751 const memory::dims &padding_l, const memory::dims &padding_r) {
4752 memory::validate_dims(strides, src_desc.data.ndims - 2);
4753 memory::validate_dims(padding_l, src_desc.data.ndims - 2);
4754 memory::validate_dims(padding_r, src_desc.data.ndims - 2);
4757 convert_to_c(algorithm), &src_desc.data,
4758 &diff_weights_desc.data, &diff_bias_desc.data,
4759 &diff_dst_desc.data, &strides[0], &padding_l[0],
4760 &padding_r[0]),
4761 "could not create a descriptor for a deconvolution weights "
4762 "update primitive");
4763 }
4764
4791 const memory::desc &diff_weights_desc,
4792 const memory::desc &diff_dst_desc, const memory::dims &strides,
4793 const memory::dims &padding_l, const memory::dims &padding_r) {
4794 memory::validate_dims(strides, src_desc.data.ndims - 2);
4795 memory::validate_dims(padding_l, src_desc.data.ndims - 2);
4796 memory::validate_dims(padding_r, src_desc.data.ndims - 2);
4798 &data, convert_to_c(algorithm),
4799 &src_desc.data, &diff_weights_desc.data,
4800 nullptr, &diff_dst_desc.data, &strides[0],
4801 &padding_l[0], &padding_r[0]),
4802 "could not create a descriptor for a deconvolution weights "
4803 "update primitive");
4804 }
4805
4837 const memory::desc &diff_weights_desc,
4838 const memory::desc &diff_bias_desc,
4839 const memory::desc &diff_dst_desc, const memory::dims &strides,
4840 const memory::dims &dilates, const memory::dims &padding_l,
4841 const memory::dims &padding_r) {
4842 memory::validate_dims(strides, src_desc.data.ndims - 2);
4843 memory::validate_dims(dilates, src_desc.data.ndims - 2);
4844 memory::validate_dims(padding_l, src_desc.data.ndims - 2);
4845 memory::validate_dims(padding_r, src_desc.data.ndims - 2);
4848 convert_to_c(algorithm), &src_desc.data,
4849 &diff_weights_desc.data, &diff_bias_desc.data,
4850 &diff_dst_desc.data, &strides[0], &dilates[0],
4851 &padding_l[0], &padding_r[0]),
4852 "could not create a descriptor for a dilated deconvolution "
4853 "weights gradient primitive");
4854 }
4855
4884 const memory::desc &diff_weights_desc,
4885 const memory::desc &diff_dst_desc, const memory::dims &strides,
4886 const memory::dims &dilates, const memory::dims &padding_l,
4887 const memory::dims &padding_r) {
4888 memory::validate_dims(strides, src_desc.data.ndims - 2);
4889 memory::validate_dims(dilates, src_desc.data.ndims - 2);
4890 memory::validate_dims(padding_l, src_desc.data.ndims - 2);
4891 memory::validate_dims(padding_r, src_desc.data.ndims - 2);
4894 convert_to_c(algorithm), &src_desc.data,
4895 &diff_weights_desc.data, nullptr,
4896 &diff_dst_desc.data, &strides[0], &dilates[0],
4897 &padding_l[0], &padding_r[0]),
4898 "could not create a descriptor for a dilated deconvolution "
4899 "weights gradient primitive");
4900 }
4901 };
4902
4906 primitive_desc() = default;
4907
4922 const deconvolution_forward::primitive_desc &hint_fwd_pd,
4923 bool allow_empty = false)
4924 : dnnl::primitive_desc(&desc.data, nullptr, engine,
4925 hint_fwd_pd.get(), allow_empty) {}
4926
4942 const engine &engine,
4943 const deconvolution_forward::primitive_desc &hint_fwd_pd,
4944 bool allow_empty = false)
4945 : dnnl::primitive_desc(&desc.data, &attr, engine, hint_fwd_pd.get(),
4946 allow_empty) {}
4947
4957
4959 memory::desc src_desc() const { return base::src_desc(0); }
4960
4965
4968
4971 return base::diff_weights_desc(1);
4972 }
4973 };
4974
4977
4982};
4983
4985
4994
4996struct lrn_forward : public primitive {
4998 struct desc {
4999 dnnl_lrn_desc_t data;
5000
5026 const memory::desc &data_desc, memory::dim local_size,
5027 float alpha, float beta, float k = 1.f) {
5030 convert_to_c(algorithm), &data_desc.data,
5031 local_size, alpha, beta, k),
5032 "could not create a descriptor for a lrn forward "
5033 "propagation primitive");
5034 }
5035 };
5036
5040 primitive_desc() = default;
5041
5052 bool allow_empty = false)
5054 &desc.data, nullptr, engine, nullptr, allow_empty) {}
5055
5067 const engine &engine, bool allow_empty = false)
5069 &desc.data, &attr, engine, nullptr, allow_empty) {}
5070
5081
5083 memory::desc src_desc() const { return base::src_desc(0); }
5084
5086 memory::desc dst_desc() const { return base::dst_desc(0); }
5087
5090 };
5091
5093 lrn_forward() = default;
5094
5099};
5100
5102struct lrn_backward : public primitive {
5104 struct desc {
5105 dnnl_lrn_desc_t data;
5106
5131 const memory::desc &diff_data_desc, memory::dim local_size,
5132 float alpha, float beta, float k = 1.f) {
5135 &diff_data_desc.data, &data_desc.data, local_size,
5136 alpha, beta, k),
5137 "could not create a descriptor for a lrn backward "
5138 "propagation primitive");
5139 }
5140 };
5141
5145 primitive_desc() = default;
5146
5160 const lrn_forward::primitive_desc &hint_fwd_pd,
5161 bool allow_empty = false)
5162 : dnnl::primitive_desc(&desc.data, nullptr, engine,
5163 hint_fwd_pd.get(), allow_empty) {}
5164
5179 const engine &engine,
5180 const lrn_forward::primitive_desc &hint_fwd_pd,
5181 bool allow_empty = false)
5182 : dnnl::primitive_desc(&desc.data, &attr, engine, hint_fwd_pd.get(),
5183 allow_empty) {}
5184
5194
5197
5200
5203 };
5204
5206 lrn_backward() = default;
5207
5212};
5213
5215
5223
5227 struct desc {
5229
5260 const memory::desc &src_desc, const memory::desc &dst_desc,
5261 const memory::dims &strides, const memory::dims &kernel,
5262 const memory::dims &padding_l, const memory::dims &padding_r) {
5263 memory::validate_dims(strides, src_desc.data.ndims - 2);
5264 memory::validate_dims(kernel, src_desc.data.ndims - 2);
5265 memory::validate_dims(padding_l, src_desc.data.ndims - 2);
5266 memory::validate_dims(padding_r, src_desc.data.ndims - 2);
5269 convert_to_c(algorithm), &src_desc.data,
5270 &dst_desc.data, &strides[0], &kernel[0],
5271 &padding_l[0], &padding_r[0]),
5272 "could not create a descriptor for a pooling forward "
5273 "propagation primitive");
5274 }
5275 };
5276
5280 primitive_desc() = default;
5281
5292 bool allow_empty = false)
5294 &desc.data, nullptr, engine, nullptr, allow_empty) {}
5295
5307 const engine &engine, bool allow_empty = false)
5309 &desc.data, &attr, engine, nullptr, allow_empty) {}
5310
5321
5323 memory::desc src_desc() const { return base::src_desc(0); }
5324
5326 memory::desc dst_desc() const { return base::dst_desc(0); }
5327
5330 };
5331
5333 pooling_forward() = default;
5334
5339};
5340
5344 struct desc {
5346
5372 desc(algorithm algorithm, const memory::desc &diff_src_desc,
5373 const memory::desc &diff_dst_desc, const memory::dims &strides,
5374 const memory::dims &kernel, const memory::dims &padding_l,
5375 const memory::dims &padding_r) {
5376 memory::validate_dims(strides, diff_src_desc.data.ndims - 2);
5377 memory::validate_dims(kernel, diff_src_desc.data.ndims - 2);
5378 memory::validate_dims(padding_l, diff_src_desc.data.ndims - 2);
5379 memory::validate_dims(padding_r, diff_src_desc.data.ndims - 2);
5382 convert_to_c(algorithm), &diff_src_desc.data,
5383 &diff_dst_desc.data, &strides[0], &kernel[0],
5384 &padding_l[0], &padding_r[0]),
5385 "could not create a descriptor for a pooling backward "
5386 "propagation primitive");
5387 }
5388 };
5389
5393 primitive_desc() = default;
5394
5408 const pooling_forward::primitive_desc &hint_fwd_pd,
5409 bool allow_empty = false)
5410 : dnnl::primitive_desc(&desc.data, nullptr, engine,
5411 hint_fwd_pd.get(), allow_empty) {}
5412
5427 const engine &engine,
5428 const pooling_forward::primitive_desc &hint_fwd_pd,
5429 bool allow_empty = false)
5430 : dnnl::primitive_desc(&desc.data, &attr, engine, hint_fwd_pd.get(),
5431 allow_empty) {}
5432
5442
5445
5448
5451 };
5452
5454 pooling_backward() = default;
5455
5460};
5461
5463
5484
5488 struct desc {
5490
5510 const memory::desc &data_desc, float alpha = 0,
5511 float beta = 0) {
5515 &data_desc.data, alpha, beta),
5516 "could not create a descriptor for an eltwise forward "
5517 "propagation primitive");
5518 }
5519 };
5520
5524 primitive_desc() = default;
5525
5537 bool allow_empty = false)
5539 &desc.data, nullptr, engine, nullptr, allow_empty) {}
5540
5553 const engine &engine, bool allow_empty = false)
5555 &desc.data, &attr, engine, nullptr, allow_empty) {}
5556
5567
5569 memory::desc src_desc() const { return base::src_desc(0); }
5570
5572 memory::desc dst_desc() const { return base::dst_desc(0); }
5573 };
5574
5576 eltwise_forward() = default;
5577
5582};
5583
5587 struct desc {
5589
5608 desc(algorithm algorithm, const memory::desc &diff_data_desc,
5609 const memory::desc &data_desc, float alpha = 0,
5610 float beta = 0) {
5613 dnnl::convert_to_c(algorithm), &diff_data_desc.data,
5614 &data_desc.data, alpha, beta),
5615 "could not create a descriptor for an eltwise backward "
5616 "propagation primitive");
5617 }
5618 };
5619
5623 primitive_desc() = default;
5624
5639 const eltwise_forward::primitive_desc &hint_fwd_pd,
5640 bool allow_empty = false)
5641 : dnnl::primitive_desc(&desc.data, nullptr, engine,
5642 hint_fwd_pd.get(), allow_empty) {}
5643
5659 const engine &engine,
5660 const eltwise_forward::primitive_desc &hint_fwd_pd,
5661 bool allow_empty = false)
5662 : dnnl::primitive_desc(&desc.data, &attr, engine, hint_fwd_pd.get(),
5663 allow_empty) {}
5664
5674
5676 memory::desc src_desc() const { return base::src_desc(0); }
5677
5680
5683 };
5684
5686 eltwise_backward() = default;
5687
5692};
5693
5695
5703
5707 struct desc {
5709
5711 desc() = default;
5712
5728 int softmax_axis) {
5731 &data_desc.data, softmax_axis),
5732 "could not create a descriptor for a softmax forward "
5733 "propagation primitive");
5734 }
5735 };
5736
5740 primitive_desc() = default;
5741
5753 bool allow_empty = false)
5755 &desc.data, nullptr, engine, nullptr, allow_empty) {}
5756
5769 const engine &engine, bool allow_empty = false)
5771 &desc.data, &attr, engine, nullptr, allow_empty) {}
5772
5783
5785 memory::desc src_desc() const { return base::src_desc(0); }
5786
5788 memory::desc dst_desc() const { return base::dst_desc(0); }
5789 };
5790
5792 softmax_forward() = default;
5793
5798};
5799
5803 struct desc {
5805
5807 desc() = default;
5808
5823 desc(const memory::desc &diff_data_desc, const memory::desc &data_desc,
5824 int softmax_axis) {
5826 dnnl_softmax_backward_desc_init(&data, &diff_data_desc.data,
5827 &data_desc.data, softmax_axis),
5828 "could not create a descriptor for a softmax backward "
5829 "propagation primitive");
5830 }
5831 };
5832
5836 primitive_desc() = default;
5837
5852 const softmax_forward::primitive_desc &hint_fwd_pd,
5853 bool allow_empty = false)
5854 : dnnl::primitive_desc(&desc.data, nullptr, engine,
5855 hint_fwd_pd.get(), allow_empty) {}
5856
5872 const engine &engine,
5873 const softmax_forward::primitive_desc &hint_fwd_pd,
5874 bool allow_empty = false)
5875 : dnnl::primitive_desc(&desc.data, &attr, engine, hint_fwd_pd.get(),
5876 allow_empty) {}
5877
5887
5889 memory::desc dst_desc() const { return base::dst_desc(0); }
5890
5893
5896 };
5897
5899 softmax_backward() = default;
5900
5905};
5906
5908
5916
5920 struct desc {
5922
5924 desc() = default;
5925
5941 int logsoftmax_axis) {
5944 &data_desc.data, logsoftmax_axis),
5945 "could not create a descriptor for a logsoftmax forward "
5946 "propagation primitive");
5947 }
5948 };
5949
5953 primitive_desc() = default;
5954
5966 bool allow_empty = false)
5968 &desc.data, nullptr, engine, nullptr, allow_empty) {}
5969
5982 const engine &engine, bool allow_empty = false)
5984 &desc.data, &attr, engine, nullptr, allow_empty) {}
5985
5993 : dnnl::primitive_desc(pd,
5994 // Logsoftmax and softmax share the implementation and
5995 // currently report the same primitive kind. Hence this
5996 // must be softmax and not logsoftmax.
6000
6002 memory::desc src_desc() const { return base::src_desc(0); }
6003
6005 memory::desc dst_desc() const { return base::dst_desc(0); }
6006 };
6007
6010
6015};
6016
6020 struct desc {
6022
6024 desc() = default;
6025
6040 desc(const memory::desc &diff_data_desc, const memory::desc &data_desc,
6041 int logsoftmax_axis) {
6043 &diff_data_desc.data, &data_desc.data,
6044 logsoftmax_axis),
6045 "could not create a descriptor for a logsoftmax backward "
6046 "propagation primitive");
6047 }
6048 };
6049
6053 primitive_desc() = default;
6054
6069 const logsoftmax_forward::primitive_desc &hint_fwd_pd,
6070 bool allow_empty = false)
6071 : dnnl::primitive_desc(&desc.data, nullptr, engine,
6072 hint_fwd_pd.get(), allow_empty) {}
6073
6089 const engine &engine,
6090 const logsoftmax_forward::primitive_desc &hint_fwd_pd,
6091 bool allow_empty = false)
6092 : dnnl::primitive_desc(&desc.data, &attr, engine, hint_fwd_pd.get(),
6093 allow_empty) {}
6094
6102 : dnnl::primitive_desc(pd,
6103 // Logsoftmax and softmax share the implementation and
6104 // currently report the same primitive kind. Hence this
6105 // must be softmax and not logsoftmax.
6108
6110 memory::desc dst_desc() const { return base::dst_desc(0); }
6111
6114
6117 };
6118
6121
6126};
6127
6129
6149
6153 struct desc {
6155
6199 desc(prop_kind prop_kind, const memory::desc &data_desc, float epsilon,
6200 normalization_flags flags) {
6203 dnnl::convert_to_c(prop_kind), &data_desc.data,
6204 epsilon, convert_to_c(flags)),
6205 "could not create a descriptor for a batch normalization "
6206 "forward propagation primitive");
6207 }
6208 };
6209
6214 primitive_desc() = default;
6215
6227 bool allow_empty = false)
6229 &desc.data, nullptr, engine, nullptr, allow_empty) {}
6230
6243 const engine &engine, bool allow_empty = false)
6245 &desc.data, &attr, engine, nullptr, allow_empty) {}
6246
6258
6260 memory::desc src_desc() const { return base::src_desc(0); }
6261
6263 memory::desc dst_desc() const { return base::dst_desc(0); }
6264
6267
6270
6273 memory::desc mean_desc() const { return stat_desc(mean); }
6274
6277 memory::desc variance_desc() const { return stat_desc(var); }
6278
6279 private:
6280 enum {
6281 mean = 1,
6282 var = 2,
6283 };
6284 memory::desc stat_desc(int kind) const {
6289 &p),
6290 "could not retrieve a descriptor from a primitive "
6291 "descriptor for batch normalization forward propagation "
6292 "primitive");
6293 return query_md(p->flags & dnnl_use_global_stats ? query::src_md
6294 : query::dst_md,
6295 kind);
6296 }
6297 };
6298
6301
6306};
6307
6311 struct desc {
6313
6345 desc(prop_kind prop_kind, const memory::desc &diff_data_desc,
6346 const memory::desc &data_desc, float epsilon,
6347 normalization_flags flags) {
6350 dnnl::convert_to_c(prop_kind), &diff_data_desc.data,
6351 &data_desc.data, epsilon, convert_to_c(flags)),
6352 "could not create a descriptor for a batch normalization "
6353 "backward propagation primitive");
6354 }
6355 };
6356
6361 primitive_desc() = default;
6362
6378 bool allow_empty = false)
6379 : dnnl::primitive_desc(&desc.data, nullptr, engine,
6380 hint_fwd_pd.get(), allow_empty) {}
6381
6397 const engine &engine,
6399 bool allow_empty = false)
6400 : dnnl::primitive_desc(&desc.data, &attr, engine, hint_fwd_pd.get(),
6401 allow_empty) {}
6402
6414
6416 memory::desc src_desc() const { return base::src_desc(0); }
6417
6420
6422 memory::desc dst_desc() const { return base::dst_desc(0); }
6423
6426
6429
6434
6437
6440 return query_md(query::src_md, 2);
6441 }
6442
6445 };
6446
6449
6454};
6455
6457
6479
6483 struct desc {
6485
6521 const memory::desc &stat_desc, float epsilon,
6522 normalization_flags flags) {
6525 dnnl::convert_to_c(prop_kind), &data_desc.data,
6526 &stat_desc.data, epsilon, convert_to_c(flags)),
6527 "could not create a descriptor for a layer normalization "
6528 "forward propagation primitive");
6529 }
6530
6564 desc(prop_kind prop_kind, const memory::desc &data_desc, float epsilon,
6565 normalization_flags flags) {
6568 dnnl::convert_to_c(prop_kind), &data_desc.data,
6569 nullptr, epsilon, convert_to_c(flags)),
6570 "could not create a descriptor for a layer normalization "
6571 "forward propagation primitive");
6572 }
6573 };
6574
6579 primitive_desc() = default;
6580
6592 bool allow_empty = false)
6594 &desc.data, nullptr, engine, nullptr, allow_empty) {}
6595
6608 const engine &engine, bool allow_empty = false)
6610 &desc.data, &attr, engine, nullptr, allow_empty) {}
6611
6623
6625 memory::desc src_desc() const { return base::src_desc(0); }
6626
6628 memory::desc dst_desc() const { return base::dst_desc(0); }
6629
6632
6635
6637 memory::desc mean_desc() const { return stat_desc(mean); }
6638
6640 memory::desc variance_desc() const { return stat_desc(var); }
6641
6642 private:
6643 enum {
6644 mean = 1,
6645 var = 2,
6646 };
6647 memory::desc stat_desc(int kind) const {
6652 &p),
6653 "could not retrieve a descriptor from a primitive "
6654 "descriptor for layer normalization forward propagation "
6655 "primitive");
6656 return query_md(p->flags & dnnl_use_global_stats ? query::src_md
6657 : query::dst_md,
6658 kind);
6659 }
6660 };
6661
6664
6669};
6670
6674 struct desc {
6676
6706 desc(prop_kind prop_kind, const memory::desc &diff_data_desc,
6707 const memory::desc &data_desc, const memory::desc &stat_desc,
6708 float epsilon, normalization_flags flags) {
6711 dnnl::convert_to_c(prop_kind), &diff_data_desc.data,
6712 &data_desc.data, &stat_desc.data, epsilon,
6713 convert_to_c(flags)),
6714 "could not create a descriptor for a batch normalization "
6715 "backward propagation primitive");
6716 }
6717
6746 desc(prop_kind prop_kind, const memory::desc &diff_data_desc,
6747 const memory::desc &data_desc, float epsilon,
6748 normalization_flags flags) {
6751 &diff_data_desc.data, &data_desc.data,
6752 nullptr, epsilon, convert_to_c(flags)),
6753 "could not create a descriptor for a batch normalization "
6754 "backward propagation primitive");
6755 }
6756 };
6757
6762 primitive_desc() = default;
6763
6779 bool allow_empty = false)
6780 : dnnl::primitive_desc(&desc.data, nullptr, engine,
6781 hint_fwd_pd.get(), allow_empty) {}
6782
6798 const engine &engine,
6800 bool allow_empty = false)
6801 : dnnl::primitive_desc(&desc.data, &attr, engine, hint_fwd_pd.get(),
6802 allow_empty) {}
6803
6815
6817 memory::desc src_desc() const { return base::src_desc(0); }
6818
6821
6823 memory::desc dst_desc() const { return base::dst_desc(0); }
6824
6827
6830
6835
6838
6841 return query_md(query::src_md, 2);
6842 }
6843
6846 };
6847
6850
6855};
6856
6858
6866
6870 struct desc {
6872
6896 const memory::desc &weights_desc, const memory::desc &bias_desc,
6897 const memory::desc &dst_desc) {
6900 &src_desc.data, &weights_desc.data,
6901 &bias_desc.data, &dst_desc.data),
6902 "could not create a descriptor for an inner product "
6903 "forward propagation primitive");
6904 }
6905
6927 const memory::desc &weights_desc,
6928 const memory::desc &dst_desc) {
6931 dnnl::convert_to_c(prop_kind), &src_desc.data,
6932 &weights_desc.data, nullptr, &dst_desc.data),
6933 "could not create a descriptor for an inner product "
6934 "forward propagation primitive");
6935 }
6936 };
6937
6941 primitive_desc() = default;
6942
6954 bool allow_empty = false)
6956 &desc.data, nullptr, engine, nullptr, allow_empty) {}
6957
6970 const engine &engine, bool allow_empty = false)
6972 &desc.data, &attr, engine, nullptr, allow_empty) {}
6973
6984
6986 memory::desc src_desc() const { return base::src_desc(0); }
6987
6990
6992 memory::desc dst_desc() const { return base::dst_desc(0); }
6993
6996 };
6997
7000
7005};
7006
7010 struct desc {
7012
7030 desc(const memory::desc &diff_src_desc,
7031 const memory::desc &weights_desc,
7032 const memory::desc &diff_dst_desc) {
7034 &diff_src_desc.data, &weights_desc.data,
7035 &diff_dst_desc.data),
7036 "could not create a descriptor for an inner product "
7037 "backward propagation primitive");
7038 }
7039 };
7040
7045 primitive_desc() = default;
7046
7061 const inner_product_forward::primitive_desc &hint_fwd_pd,
7062 bool allow_empty = false)
7063 : dnnl::primitive_desc(&desc.data, nullptr, engine,
7064 hint_fwd_pd.get(), allow_empty) {}
7065
7081 const engine &engine,
7082 const inner_product_forward::primitive_desc &hint_fwd_pd,
7083 bool allow_empty = false)
7084 : dnnl::primitive_desc(&desc.data, &attr, engine, hint_fwd_pd.get(),
7085 allow_empty) {}
7086
7096
7099
7102
7105 };
7106
7109
7114};
7115
7119 struct desc {
7121
7141 desc(const memory::desc &src_desc,
7142 const memory::desc &diff_weights_desc,
7143 const memory::desc &diff_bias_desc,
7144 const memory::desc &diff_dst_desc) {
7147 &src_desc.data, &diff_weights_desc.data,
7148 &diff_bias_desc.data, &diff_dst_desc.data),
7149 "could not create a descriptor for an inner product "
7150 "weights gradient primitive");
7151 }
7152
7170 desc(const memory::desc &src_desc,
7171 const memory::desc &diff_weights_desc,
7172 const memory::desc &diff_dst_desc) {
7175 &src_desc.data, &diff_weights_desc.data, nullptr,
7176 &diff_dst_desc.data),
7177 "could not create a descriptor for an inner product "
7178 "weights gradient primitive");
7179 }
7180 };
7181
7185 primitive_desc() = default;
7186
7201 const inner_product_forward::primitive_desc &hint_fwd_pd,
7202 bool allow_empty = false)
7203 : dnnl::primitive_desc(&desc.data, nullptr, engine,
7204 hint_fwd_pd.get(), allow_empty) {}
7205
7221 const engine &engine,
7222 const inner_product_forward::primitive_desc &hint_fwd_pd,
7223 bool allow_empty = false)
7224 : dnnl::primitive_desc(&desc.data, &attr, engine, hint_fwd_pd.get(),
7225 allow_empty) {}
7226
7236
7238 memory::desc src_desc() const { return base::src_desc(0); }
7239
7244
7247
7250 return base::diff_weights_desc(1);
7251 }
7252 };
7253
7256
7261};
7262
7264
7272
7274struct rnn_primitive_desc_base : public primitive_desc {
7275 using primitive_desc::primitive_desc;
7276
7279
7290
7296
7304
7310
7316
7322
7328
7334
7342
7348
7356
7362
7368
7376
7382
7388
7394
7401
7408
7416
7422
7430
7436
7437protected:
7438 using rnn_base = rnn_primitive_desc_base;
7439
7440 // (Deliberately not using doxygen comments)
7441 //
7442 // Constructs an RNN primitive descriptor base from a C API primitive
7443 // descriptor while checking that it actually describes the expected
7444 // primitive by comparing propagation and primitive kinds. Caller can
7445 // pass two options propagation kinds. This is typically used to check
7446 // that propagation kind is inference or training forward propagation.
7447 //
7448 // @param pd C API primitive descriptor.
7449 // @param prop_kind1 Expected propagation kind.
7450 // @param prop_kind2 Expected propagation kind.
7451 // @param cell_kind Expected cell kind.
7453 dnnl::prop_kind prop_kind1, dnnl::prop_kind prop_kind2,
7454 dnnl::algorithm cell_kind) {
7456 dnnl_status_t rc;
7459 "could not retrieve a descriptor from a primitive descriptor "
7460 "for an RNN primitive");
7461
7462 dnnl_prop_kind_t c_prop_kind1 = convert_to_c(prop_kind1);
7463 dnnl_prop_kind_t c_prop_kind2 = convert_to_c(prop_kind2);
7464 dnnl_alg_kind_t c_cell_kind = convert_to_c(cell_kind);
7465
7466 bool ok = rnn_d->primitive_kind == dnnl_rnn
7467 && (rnn_d->prop_kind == c_prop_kind1
7468 || rnn_d->prop_kind == c_prop_kind2)
7469 && rnn_d->cell_kind == c_cell_kind;
7470
7471 if (!ok)
7472 DNNL_THROW_ERROR(dnnl_invalid_arguments,
7473 "mismatch between expected and provided descriptors for an "
7474 "RNN primitive");
7475
7476 reset_with_clone(pd);
7477 }
7478};
7479
7483 struct desc {
7484 dnnl_rnn_desc_t data;
7485
7540 const memory::desc &src_layer_desc,
7541 const memory::desc &src_iter_desc,
7542 const memory::desc &weights_layer_desc,
7543 const memory::desc &weights_iter_desc,
7544 const memory::desc &bias_desc,
7545 const memory::desc &dst_layer_desc,
7546 const memory::desc &dst_iter_desc,
7547 rnn_flags flags = rnn_flags::undef, float alpha = 0.0f,
7548 float beta = 0.0f) {
7552 dnnl::convert_to_c(activation),
7553 dnnl::convert_to_c(direction), &src_layer_desc.data,
7554 &src_iter_desc.data, &weights_layer_desc.data,
7555 &weights_iter_desc.data, &bias_desc.data,
7556 &dst_layer_desc.data, &dst_iter_desc.data,
7557 dnnl::convert_to_c(flags), alpha, beta),
7558 "could not create a descriptor for a vanilla RNN forward "
7559 "propagation primitive");
7560 }
7561 };
7562
7566 primitive_desc() = default;
7567
7579 bool allow_empty = false)
7581 &desc.data, nullptr, engine, nullptr, allow_empty) {}
7582
7595 const engine &engine, bool allow_empty = false)
7597 &desc.data, &attr, engine, nullptr, allow_empty) {}
7598
7609
7614
7617
7622
7627
7630
7635
7638
7643 };
7644
7647
7652};
7653
7657 struct desc {
7658 dnnl_rnn_desc_t data;
7659
7736 const memory::desc &src_layer_desc,
7737 const memory::desc &src_iter_desc,
7738 const memory::desc &weights_layer_desc,
7739 const memory::desc &weights_iter_desc,
7740 const memory::desc &bias_desc,
7741 const memory::desc &dst_layer_desc,
7742 const memory::desc &dst_iter_desc,
7743 const memory::desc &diff_src_layer_desc,
7744 const memory::desc &diff_src_iter_desc,
7745 const memory::desc &diff_weights_layer_desc,
7746 const memory::desc &diff_weights_iter_desc,
7747 const memory::desc &diff_bias_desc,
7748 const memory::desc &diff_dst_layer_desc,
7749 const memory::desc &diff_dst_iter_desc,
7750 rnn_flags flags = rnn_flags::undef, float alpha = 0.0f,
7751 float beta = 0.0f) {
7755 dnnl::convert_to_c(activation),
7756 dnnl::convert_to_c(direction), &src_layer_desc.data,
7757 &src_iter_desc.data, &weights_layer_desc.data,
7758 &weights_iter_desc.data, &bias_desc.data,
7759 &dst_layer_desc.data, &dst_iter_desc.data,
7760 &diff_src_layer_desc.data, &diff_src_iter_desc.data,
7761 &diff_weights_layer_desc.data,
7762 &diff_weights_iter_desc.data, &diff_bias_desc.data,
7763 &diff_dst_layer_desc.data, &diff_dst_iter_desc.data,
7764 dnnl::convert_to_c(flags), alpha, beta),
7765 "could not create a descriptor for a vanilla RNN backward "
7766 "propagation primitive");
7767 }
7768 };
7769
7773 primitive_desc() = default;
7774
7789 const vanilla_rnn_forward::primitive_desc &hint_fwd_pd,
7790 bool allow_empty = false)
7791 : rnn_primitive_desc_base(&desc.data, nullptr, engine,
7792 hint_fwd_pd.get(), allow_empty) {}
7793
7809 const engine &engine,
7810 const vanilla_rnn_forward::primitive_desc &hint_fwd_pd,
7811 bool allow_empty = false)
7812 : rnn_primitive_desc_base(&desc.data, &attr, engine,
7813 hint_fwd_pd.get(), allow_empty) {}
7814
7824
7829
7832
7837
7842
7845
7850
7853
7858
7863
7868
7873
7878
7883
7888
7893 };
7894
7897
7902};
7903
7905struct lstm_forward : public primitive {
7907 struct desc {
7908 dnnl_rnn_desc_t data;
7909
7984 const memory::desc &src_layer_desc,
7985 const memory::desc &src_iter_desc,
7986 const memory::desc &src_iter_c_desc,
7987 const memory::desc &weights_layer_desc,
7988 const memory::desc &weights_iter_desc,
7989 const memory::desc &weights_peephole_desc,
7990 const memory::desc &weights_projection_desc,
7991 const memory::desc &bias_desc,
7992 const memory::desc &dst_layer_desc,
7993 const memory::desc &dst_iter_desc,
7994 const memory::desc &dst_iter_c_desc,
7995 rnn_flags flags = rnn_flags::undef) {
7999 dnnl::convert_to_c(direction), &src_layer_desc.data,
8000 &src_iter_desc.data, &src_iter_c_desc.data,
8001 &weights_layer_desc.data, &weights_iter_desc.data,
8002 &weights_peephole_desc.data,
8003 &weights_projection_desc.data, &bias_desc.data,
8004 &dst_layer_desc.data, &dst_iter_desc.data,
8005 &dst_iter_c_desc.data, dnnl::convert_to_c(flags)),
8006 "could not create a descriptor for an LSTM forward "
8007 "propagation primitive");
8008 }
8009
8071 const memory::desc &src_layer_desc,
8072 const memory::desc &src_iter_desc,
8073 const memory::desc &src_iter_c_desc,
8074 const memory::desc &weights_layer_desc,
8075 const memory::desc &weights_iter_desc,
8076 const memory::desc &weights_peephole_desc,
8077 const memory::desc &bias_desc,
8078 const memory::desc &dst_layer_desc,
8079 const memory::desc &dst_iter_desc,
8080 const memory::desc &dst_iter_c_desc,
8081 rnn_flags flags = rnn_flags::undef) {
8085 dnnl::convert_to_c(direction), &src_layer_desc.data,
8086 &src_iter_desc.data, &src_iter_c_desc.data,
8087 &weights_layer_desc.data, &weights_iter_desc.data,
8088 &weights_peephole_desc.data, &bias_desc.data,
8089 &dst_layer_desc.data, &dst_iter_desc.data,
8090 &dst_iter_c_desc.data, dnnl::convert_to_c(flags)),
8091 "could not create a descriptor for an LSTM forward "
8092 "propagation primitive");
8093 }
8094
8147 const memory::desc &src_layer_desc,
8148 const memory::desc &src_iter_desc,
8149 const memory::desc &src_iter_c_desc,
8150 const memory::desc &weights_layer_desc,
8151 const memory::desc &weights_iter_desc,
8152 const memory::desc &bias_desc,
8153 const memory::desc &dst_layer_desc,
8154 const memory::desc &dst_iter_desc,
8155 const memory::desc &dst_iter_c_desc,
8156 rnn_flags flags = rnn_flags::undef) {
8160 dnnl::convert_to_c(direction), &src_layer_desc.data,
8161 &src_iter_desc.data, &src_iter_c_desc.data,
8162 &weights_layer_desc.data, &weights_iter_desc.data,
8163 &bias_desc.data, &dst_layer_desc.data,
8164 &dst_iter_desc.data, &dst_iter_c_desc.data,
8165 dnnl::convert_to_c(flags)),
8166 "could not create a descriptor for an LSTM forward "
8167 "propagation primitive");
8168 }
8169 };
8170
8174 primitive_desc() = default;
8175
8186 bool allow_empty = false)
8188 &desc.data, nullptr, engine, nullptr, allow_empty) {}
8189
8201 const engine &engine, bool allow_empty = false)
8203 &desc.data, &attr, engine, nullptr, allow_empty) {}
8204
8215
8220
8223
8228
8233
8238
8243
8248
8251
8256
8259
8264
8269 };
8270
8272 lstm_forward() = default;
8273
8278};
8279
8281struct lstm_backward : public primitive {
8283 struct desc {
8284 dnnl_rnn_desc_t data;
8285
8414 const memory::desc &src_layer_desc,
8415 const memory::desc &src_iter_desc,
8416 const memory::desc &src_iter_c_desc,
8417 const memory::desc &weights_layer_desc,
8418 const memory::desc &weights_iter_desc,
8419 const memory::desc &weights_peephole_desc,
8420 const memory::desc &weights_projection_desc,
8421 const memory::desc &bias_desc,
8422 const memory::desc &dst_layer_desc,
8423 const memory::desc &dst_iter_desc,
8424 const memory::desc &dst_iter_c_desc,
8425 const memory::desc &diff_src_layer_desc,
8426 const memory::desc &diff_src_iter_desc,
8427 const memory::desc &diff_src_iter_c_desc,
8428 const memory::desc &diff_weights_layer_desc,
8429 const memory::desc &diff_weights_iter_desc,
8430 const memory::desc &diff_weights_peephole_desc,
8431 const memory::desc &diff_weights_projection_desc,
8432 const memory::desc &diff_bias_desc,
8433 const memory::desc &diff_dst_layer_desc,
8434 const memory::desc &diff_dst_iter_desc,
8435 const memory::desc &diff_dst_iter_c_desc,
8436 rnn_flags flags = rnn_flags::undef) {
8440 dnnl::convert_to_c(direction), &src_layer_desc.data,
8441 &src_iter_desc.data, &src_iter_c_desc.data,
8442 &weights_layer_desc.data, &weights_iter_desc.data,
8443 &weights_peephole_desc.data,
8444 &weights_projection_desc.data, &bias_desc.data,
8445 &dst_layer_desc.data, &dst_iter_desc.data,
8446 &dst_iter_c_desc.data, &diff_src_layer_desc.data,
8447 &diff_src_iter_desc.data,
8448 &diff_src_iter_c_desc.data,
8449 &diff_weights_layer_desc.data,
8450 &diff_weights_iter_desc.data,
8451 &diff_weights_peephole_desc.data,
8452 &diff_weights_projection_desc.data,
8453 &diff_bias_desc.data, &diff_dst_layer_desc.data,
8454 &diff_dst_iter_desc.data,
8455 &diff_dst_iter_c_desc.data,
8456 dnnl::convert_to_c(flags)),
8457 "could not create a descriptor for an LSTM backward "
8458 "propagation primitive");
8459 }
8460
8562 const memory::desc &src_layer_desc,
8563 const memory::desc &src_iter_desc,
8564 const memory::desc &src_iter_c_desc,
8565 const memory::desc &weights_layer_desc,
8566 const memory::desc &weights_iter_desc,
8567 const memory::desc &weights_peephole_desc,
8568 const memory::desc &bias_desc,
8569 const memory::desc &dst_layer_desc,
8570 const memory::desc &dst_iter_desc,
8571 const memory::desc &dst_iter_c_desc,
8572 const memory::desc &diff_src_layer_desc,
8573 const memory::desc &diff_src_iter_desc,
8574 const memory::desc &diff_src_iter_c_desc,
8575 const memory::desc &diff_weights_layer_desc,
8576 const memory::desc &diff_weights_iter_desc,
8577 const memory::desc &diff_weights_peephole_desc,
8578 const memory::desc &diff_bias_desc,
8579 const memory::desc &diff_dst_layer_desc,
8580 const memory::desc &diff_dst_iter_desc,
8581 const memory::desc &diff_dst_iter_c_desc,
8582 rnn_flags flags = rnn_flags::undef) {
8586 dnnl::convert_to_c(direction), &src_layer_desc.data,
8587 &src_iter_desc.data, &src_iter_c_desc.data,
8588 &weights_layer_desc.data, &weights_iter_desc.data,
8589 &weights_peephole_desc.data, &bias_desc.data,
8590 &dst_layer_desc.data, &dst_iter_desc.data,
8591 &dst_iter_c_desc.data, &diff_src_layer_desc.data,
8592 &diff_src_iter_desc.data,
8593 &diff_src_iter_c_desc.data,
8594 &diff_weights_layer_desc.data,
8595 &diff_weights_iter_desc.data,
8596 &diff_weights_peephole_desc.data,
8597 &diff_bias_desc.data, &diff_dst_layer_desc.data,
8598 &diff_dst_iter_desc.data,
8599 &diff_dst_iter_c_desc.data,
8600 dnnl::convert_to_c(flags)),
8601 "could not create a descriptor for an LSTM backward "
8602 "propagation primitive");
8603 }
8604
8688 const memory::desc &src_layer_desc,
8689 const memory::desc &src_iter_desc,
8690 const memory::desc &src_iter_c_desc,
8691 const memory::desc &weights_layer_desc,
8692 const memory::desc &weights_iter_desc,
8693 const memory::desc &bias_desc,
8694 const memory::desc &dst_layer_desc,
8695 const memory::desc &dst_iter_desc,
8696 const memory::desc &dst_iter_c_desc,
8697 const memory::desc &diff_src_layer_desc,
8698 const memory::desc &diff_src_iter_desc,
8699 const memory::desc &diff_src_iter_c_desc,
8700 const memory::desc &diff_weights_layer_desc,
8701 const memory::desc &diff_weights_iter_desc,
8702 const memory::desc &diff_bias_desc,
8703 const memory::desc &diff_dst_layer_desc,
8704 const memory::desc &diff_dst_iter_desc,
8705 const memory::desc &diff_dst_iter_c_desc,
8706 rnn_flags flags = rnn_flags::undef) {
8710 dnnl::convert_to_c(direction), &src_layer_desc.data,
8711 &src_iter_desc.data, &src_iter_c_desc.data,
8712 &weights_layer_desc.data, &weights_iter_desc.data,
8713 &bias_desc.data, &dst_layer_desc.data,
8714 &dst_iter_desc.data, &dst_iter_c_desc.data,
8715 &diff_src_layer_desc.data, &diff_src_iter_desc.data,
8716 &diff_src_iter_c_desc.data,
8717 &diff_weights_layer_desc.data,
8718 &diff_weights_iter_desc.data, &diff_bias_desc.data,
8719 &diff_dst_layer_desc.data, &diff_dst_iter_desc.data,
8720 &diff_dst_iter_c_desc.data,
8721 dnnl::convert_to_c(flags)),
8722 "could not create a descriptor for an LSTM backward "
8723 "propagation primitive");
8724 }
8725 };
8726
8730 primitive_desc() = default;
8731
8745 const lstm_forward::primitive_desc &hint_fwd_pd,
8746 bool allow_empty = false)
8747 : rnn_primitive_desc_base(&desc.data, nullptr, engine,
8748 hint_fwd_pd.get(), allow_empty) {}
8749
8764 const engine &engine,
8765 const lstm_forward::primitive_desc &hint_fwd_pd,
8766 bool allow_empty = false)
8767 : rnn_primitive_desc_base(&desc.data, &attr, engine,
8768 hint_fwd_pd.get(), allow_empty) {}
8769
8779
8784
8787
8792
8797
8802
8807
8812
8815
8820
8823
8828
8833
8838
8843
8848
8853
8858
8863
8868
8873
8878
8883
8888 };
8889
8891 lstm_backward() = default;
8892
8897};
8898
8900struct gru_forward : public primitive {
8902 struct desc {
8903 dnnl_rnn_desc_t data;
8904
8951 const memory::desc &src_layer_desc,
8952 const memory::desc &src_iter_desc,
8953 const memory::desc &weights_layer_desc,
8954 const memory::desc &weights_iter_desc,
8955 const memory::desc &bias_desc,
8956 const memory::desc &dst_layer_desc,
8957 const memory::desc &dst_iter_desc,
8958 rnn_flags flags = rnn_flags::undef) {
8962 dnnl::convert_to_c(direction), &src_layer_desc.data,
8963 &src_iter_desc.data, &weights_layer_desc.data,
8964 &weights_iter_desc.data, &bias_desc.data,
8965 &dst_layer_desc.data, &dst_iter_desc.data,
8966 dnnl::convert_to_c(flags)),
8967 "could not create a descriptor for a GRU forward "
8968 "propagation primitive");
8969 }
8970 };
8971
8975 primitive_desc() = default;
8976
8987 bool allow_empty = false)
8989 &desc.data, nullptr, engine, nullptr, allow_empty) {}
8990
9002 const engine &engine, bool allow_empty = false)
9004 &desc.data, &attr, engine, nullptr, allow_empty) {}
9005
9016
9021
9024
9029
9034
9037
9042
9045
9050 };
9051
9053 gru_forward() = default;
9054
9059};
9060
9062struct gru_backward : public primitive {
9064 struct desc {
9065 dnnl_rnn_desc_t data;
9066
9134 const memory::desc &src_layer_desc,
9135 const memory::desc &src_iter_desc,
9136 const memory::desc &weights_layer_desc,
9137 const memory::desc &weights_iter_desc,
9138 const memory::desc &bias_desc,
9139 const memory::desc &dst_layer_desc,
9140 const memory::desc &dst_iter_desc,
9141 const memory::desc &diff_src_layer_desc,
9142 const memory::desc &diff_src_iter_desc,
9143 const memory::desc &diff_weights_layer_desc,
9144 const memory::desc &diff_weights_iter_desc,
9145 const memory::desc &diff_bias_desc,
9146 const memory::desc &diff_dst_layer_desc,
9147 const memory::desc &diff_dst_iter_desc,
9148 rnn_flags flags = rnn_flags::undef) {
9152 dnnl::convert_to_c(direction), &src_layer_desc.data,
9153 &src_iter_desc.data, &weights_layer_desc.data,
9154 &weights_iter_desc.data, &bias_desc.data,
9155 &dst_layer_desc.data, &dst_iter_desc.data,
9156 &diff_src_layer_desc.data, &diff_src_iter_desc.data,
9157 &diff_weights_layer_desc.data,
9158 &diff_weights_iter_desc.data, &diff_bias_desc.data,
9159 &diff_dst_layer_desc.data, &diff_dst_iter_desc.data,
9160 dnnl::convert_to_c(flags)),
9161 "could not create a descriptor for a GRU backward "
9162 "propagation primitive");
9163 }
9164 };
9165
9169 primitive_desc() = default;
9170
9184 const gru_forward::primitive_desc &hint_fwd_pd,
9185 bool allow_empty = false)
9186 : rnn_primitive_desc_base(&desc.data, nullptr, engine,
9187 hint_fwd_pd.get(), allow_empty) {}
9188
9203 const engine &engine,
9204 const gru_forward::primitive_desc &hint_fwd_pd,
9205 bool allow_empty = false)
9206 : rnn_primitive_desc_base(&desc.data, &attr, engine,
9207 hint_fwd_pd.get(), allow_empty) {}
9208
9218
9223
9226
9231
9236
9239
9244
9247
9252
9257
9262
9267
9272
9277
9282
9287 };
9288
9290 gru_backward() = default;
9291
9296};
9297
9301 struct desc {
9302 dnnl_rnn_desc_t data;
9303
9350 const memory::desc &src_layer_desc,
9351 const memory::desc &src_iter_desc,
9352 const memory::desc &weights_layer_desc,
9353 const memory::desc &weights_iter_desc,
9354 const memory::desc &bias_desc,
9355 const memory::desc &dst_layer_desc,
9356 const memory::desc &dst_iter_desc,
9357 rnn_flags flags = rnn_flags::undef) {
9361 dnnl::convert_to_c(direction), &src_layer_desc.data,
9362 &src_iter_desc.data, &weights_layer_desc.data,
9363 &weights_iter_desc.data, &bias_desc.data,
9364 &dst_layer_desc.data, &dst_iter_desc.data,
9365 dnnl::convert_to_c(flags)),
9366 "could not create a descriptor for an LBR GRU forward "
9367 "propagation primitive");
9368 }
9369 };
9370
9374 primitive_desc() = default;
9375
9387 bool allow_empty = false)
9389 &desc.data, nullptr, engine, nullptr, allow_empty) {}
9390
9403 const engine &engine, bool allow_empty = false)
9405 &desc.data, &attr, engine, nullptr, allow_empty) {}
9406
9417
9422
9425
9430
9435
9438
9443
9446
9451 };
9452
9454 lbr_gru_forward() = default;
9455
9460};
9461
9465 struct desc {
9466 dnnl_rnn_desc_t data;
9467
9536 const memory::desc &src_layer_desc,
9537 const memory::desc &src_iter_desc,
9538 const memory::desc &weights_layer_desc,
9539 const memory::desc &weights_iter_desc,
9540 const memory::desc &bias_desc,
9541 const memory::desc &dst_layer_desc,
9542 const memory::desc &dst_iter_desc,
9543 const memory::desc &diff_src_layer_desc,
9544 const memory::desc &diff_src_iter_desc,
9545 const memory::desc &diff_weights_layer_desc,
9546 const memory::desc &diff_weights_iter_desc,
9547 const memory::desc &diff_bias_desc,
9548 const memory::desc &diff_dst_layer_desc,
9549 const memory::desc &diff_dst_iter_desc,
9550 rnn_flags flags = rnn_flags::undef) {
9554 dnnl::convert_to_c(direction), &src_layer_desc.data,
9555 &src_iter_desc.data, &weights_layer_desc.data,
9556 &weights_iter_desc.data, &bias_desc.data,
9557 &dst_layer_desc.data, &dst_iter_desc.data,
9558 &diff_src_layer_desc.data, &diff_src_iter_desc.data,
9559 &diff_weights_layer_desc.data,
9560 &diff_weights_iter_desc.data, &diff_bias_desc.data,
9561 &diff_dst_layer_desc.data, &diff_dst_iter_desc.data,
9562 dnnl::convert_to_c(flags)),
9563 "could not create a descriptor for an LBR GRU backward "
9564 "propagation primitive");
9565 }
9566 };
9567
9571 primitive_desc() = default;
9572
9587 const lbr_gru_forward::primitive_desc &hint_fwd_pd,
9588 bool allow_empty = false)
9589 : rnn_primitive_desc_base(&desc.data, nullptr, engine,
9590 hint_fwd_pd.get(), allow_empty) {}
9591
9607 const engine &engine,
9608 const lbr_gru_forward::primitive_desc &hint_fwd_pd,
9609 bool allow_empty = false)
9610 : rnn_primitive_desc_base(&desc.data, &attr, engine,
9611 hint_fwd_pd.get(), allow_empty) {}
9612
9622
9627
9630
9635
9640
9643
9648
9651
9656
9661
9666
9671
9676
9681
9686
9691 };
9692
9694 lbr_gru_backward() = default;
9695
9700};
9701
9703
9711
9715 struct desc {
9717
9733 desc(prop_kind prop_kind, const memory::desc &data_desc, int axis,
9734 int group_size) {
9737 &data_desc.data, axis, group_size),
9738 "could not create a descriptor for a shuffle forward "
9739 "propagation primitive");
9740 }
9741 };
9742
9746 primitive_desc() = default;
9747
9760 const primitive_attr &attr = primitive_attr(),
9761 bool allow_empty = false)
9763 &desc.data, &attr, engine, nullptr, allow_empty) {}
9764
9775
9777 memory::desc src_desc() const { return base::src_desc(0); }
9778
9780 memory::desc dst_desc() const { return base::dst_desc(0); }
9781 };
9782
9784 shuffle_forward() = default;
9785
9790};
9791
9796 struct desc {
9798
9812 desc(const memory::desc &diff_data_desc, int axis, int group_size) {
9814 &diff_data_desc.data, axis, group_size),
9815 "could not create a descriptor for a shuffle backward "
9816 "propagation primitive");
9817 }
9818 };
9819
9823 primitive_desc() = default;
9824
9840 const shuffle_forward::primitive_desc &hint_fwd_pd,
9841 const primitive_attr &attr = primitive_attr(),
9842 bool allow_empty = false)
9843 : dnnl::primitive_desc(&desc.data, &attr, engine, hint_fwd_pd.get(),
9844 allow_empty) {}
9845
9855
9858
9861 };
9862
9864 shuffle_backward() = default;
9865
9870};
9871
9873
9881
9883struct binary : public primitive {
9885 struct desc {
9888
9890 desc() = default;
9891
9907 const memory::desc &src1, const memory::desc &dst) {
9910 &src0.data, &src1.data, &dst.data),
9911 "could not create a descriptor for a binary operation "
9912 "primitive");
9913 }
9914 };
9915
9919 primitive_desc() = default;
9920
9931 bool allow_empty = false)
9933 &desc.data, nullptr, engine, nullptr, allow_empty) {}
9934
9946 const engine &engine, bool allow_empty = false)
9948 &desc.data, &attr, engine, nullptr, allow_empty) {}
9949
9956
9958 memory::desc src_desc(int idx = 0) const { return base::src_desc(idx); }
9959
9962
9965
9967 memory::desc dst_desc() const { return base::dst_desc(0); }
9968 };
9969
9971 binary() = default;
9972
9976 binary(const primitive_desc &pd) : primitive(pd) {}
9977};
9978
9980
9990
9992struct matmul : public primitive {
9994 struct desc {
9995 dnnl_matmul_desc_t data;
9996
10009 desc(const memory::desc &src_desc, const memory::desc &weights_desc,
10010 const memory::desc &dst_desc) {
10012 dnnl_matmul_desc_init(&data, &src_desc.data,
10013 &weights_desc.data, nullptr, &dst_desc.data),
10014 "could not create a descriptor for a matmul primitive");
10015 }
10016
10031 desc(const memory::desc &src_desc, const memory::desc &weights_desc,
10032 const memory::desc &bias_desc, const memory::desc &dst_desc) {
10034 &weights_desc.data, &bias_desc.data,
10035 &dst_desc.data),
10036 "could not create a descriptor for a matmul primitive");
10037 }
10038 };
10039
10043 primitive_desc() = default;
10044
10054 bool allow_empty = false)
10056 &desc.data, nullptr, engine, nullptr, allow_empty) {}
10057
10068 const engine &engine, bool allow_empty = false)
10070 &desc.data, &attr, engine, nullptr, allow_empty) {}
10071
10078
10081
10084 return query_md(query::weights_md, 0);
10085 }
10086
10089 return query_md(query::weights_md, 1);
10090 }
10091
10094 };
10095
10097 matmul() = default;
10098
10101 matmul(const primitive_desc &pd) : primitive(pd) {}
10102};
10103
10105
10115
10119 struct desc {
10121
10134 //
10144 const memory::desc &src_desc, const memory::desc &dst_desc) {
10147 convert_to_c(algorithm), nullptr,
10148 &src_desc.data, &dst_desc.data),
10149 "could not create a resampling forward descriptor");
10150 }
10151
10167 const std::vector<float> &factors,
10168 const memory::desc &src_desc) {
10169 memory::validate_dims(factors, src_desc.data.ndims - 2);
10172 convert_to_c(algorithm), &factors[0],
10173 &src_desc.data, nullptr),
10174 "could not create a resampling forward descriptor");
10175 }
10176
10189 //
10200 const std::vector<float> &factors, const memory::desc &src_desc,
10201 const memory::desc &dst_desc) {
10202 if (!factors.empty())
10203 memory::validate_dims(factors, src_desc.data.ndims - 2);
10206 convert_to_c(algorithm), factors.data(),
10207 &src_desc.data, &dst_desc.data),
10208 "could not create a resampling forward descriptor");
10209 }
10210 };
10211
10215 primitive_desc() = default;
10216
10228 bool allow_empty = false)
10230 &desc.data, nullptr, engine, nullptr, allow_empty) {}
10231
10244 const engine &engine, bool allow_empty = false)
10246 &desc.data, &attr, engine, nullptr, allow_empty) {}
10247
10258
10260 memory::desc src_desc() const { return base::src_desc(0); }
10261
10263 memory::desc dst_desc() const { return base::dst_desc(0); }
10264 };
10265
10268
10273};
10274
10278 struct desc {
10280
10295 desc(algorithm algorithm, const memory::desc &diff_src_desc,
10296 const memory::desc &diff_dst_desc) {
10298 convert_to_c(algorithm), nullptr,
10299 &diff_src_desc.data, &diff_dst_desc.data),
10300 "could not create a resampling backward data descriptor");
10301 }
10302
10318 desc(algorithm algorithm, const std::vector<float> &factors,
10319 const memory::desc &diff_src_desc,
10320 const memory::desc &diff_dst_desc) {
10321 if (!factors.empty())
10322 memory::validate_dims(factors, diff_src_desc.data.ndims - 2);
10324 convert_to_c(algorithm), factors.data(),
10325 &diff_src_desc.data, &diff_dst_desc.data),
10326 "could not create a resampling backward data descriptor");
10327 }
10328 };
10329
10333 primitive_desc() = default;
10334
10349 const resampling_forward::primitive_desc &hint_fwd_pd,
10350 bool allow_empty = false)
10351 : dnnl::primitive_desc(&desc.data, nullptr, engine,
10352 hint_fwd_pd.get(), allow_empty) {}
10353
10369 const engine &engine,
10370 const resampling_forward::primitive_desc &hint_fwd_pd,
10371 bool allow_empty = false)
10372 : dnnl::primitive_desc(&desc.data, &attr, engine, hint_fwd_pd.get(),
10373 allow_empty) {}
10374
10384
10387
10390 };
10391
10394
10399};
10400
10402
10404
10410
10413
10431
10433inline status set_verbose(int level) {
10434 return static_cast<status>(dnnl_set_verbose(level));
10435}
10436
10438inline const version_t *version() {
10439 return dnnl_version();
10440}
10441
10443inline status set_jit_dump(int enable) {
10444 return static_cast<status>(dnnl_set_jit_dump(enable));
10445}
10446
10448inline status set_jit_profiling_flags(unsigned flags) {
10449 return static_cast<status>(dnnl_set_jit_profiling_flags(flags));
10450}
10451
10453inline status set_jit_profiling_jitdumpdir(const std::string &dir) {
10454 return static_cast<status>(dnnl_set_jit_profiling_jitdumpdir(dir.c_str()));
10455}
10456
10478
10481 return static_cast<status>(
10482 dnnl_set_max_cpu_isa(static_cast<dnnl_cpu_isa_t>(isa)));
10483}
10484
10486
10493
10495inline status sgemm(char transa, char transb, dnnl_dim_t M, dnnl_dim_t N,
10496 dnnl_dim_t K, float alpha, const float *A, dnnl_dim_t lda,
10497 const float *B, dnnl_dim_t ldb, float beta, float *C, dnnl_dim_t ldc) {
10498 return static_cast<status>(dnnl_sgemm(
10499 transa, transb, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc));
10500}
10501
10503inline status gemm_u8s8s32(char transa, char transb, char offsetc, dnnl_dim_t M,
10504 dnnl_dim_t N, dnnl_dim_t K, float alpha, const uint8_t *A,
10505 dnnl_dim_t lda, uint8_t ao, const int8_t *B, dnnl_dim_t ldb, int8_t bo,
10506 float beta, int32_t *C, dnnl_dim_t ldc, const int32_t *co) {
10507 return static_cast<status>(dnnl_gemm_u8s8s32(transa, transb, offsetc, M, N,
10508 K, alpha, A, lda, ao, B, ldb, bo, beta, C, ldc, co));
10509}
10510
10512inline status gemm_s8s8s32(char transa, char transb, char offsetc, dnnl_dim_t M,
10513 dnnl_dim_t N, dnnl_dim_t K, float alpha, const int8_t *A,
10514 dnnl_dim_t lda, int8_t ao, const int8_t *B, dnnl_dim_t ldb, int8_t bo,
10515 float beta, int32_t *C, dnnl_dim_t ldc, const int32_t *co) {
10516 return static_cast<status>(dnnl_gemm_s8s8s32(transa, transb, offsetc, M, N,
10517 K, alpha, A, lda, ao, B, ldb, bo, beta, C, ldc, co));
10518}
10519
10520#if DNNL_CPU_RUNTIME == DNNL_RUNTIME_THREADPOOL
10522inline status sgemm(char transa, char transb, dnnl_dim_t M, dnnl_dim_t N,
10523 dnnl_dim_t K, float alpha, const float *A, dnnl_dim_t lda,
10524 const float *B, dnnl_dim_t ldb, float beta, float *C, dnnl_dim_t ldc,
10526 return static_cast<status>(dnnl_sgemm_tp(
10527 transa, transb, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc, tp));
10528}
10530inline status gemm_u8s8s32(char transa, char transb, char offsetc, dnnl_dim_t M,
10531 dnnl_dim_t N, dnnl_dim_t K, float alpha, const uint8_t *A,
10532 dnnl_dim_t lda, uint8_t ao, const int8_t *B, dnnl_dim_t ldb, int8_t bo,
10533 float beta, int32_t *C, dnnl_dim_t ldc, const int32_t *co,
10534 dnnl::threadpool_iface *tp) {
10535 return static_cast<status>(dnnl_gemm_u8s8s32_tp(transa, transb, offsetc, M,
10536 N, K, alpha, A, lda, ao, B, ldb, bo, beta, C, ldc, co, tp));
10537}
10538
10540inline status gemm_s8s8s32(char transa, char transb, char offsetc, dnnl_dim_t M,
10541 dnnl_dim_t N, dnnl_dim_t K, float alpha, const int8_t *A,
10542 dnnl_dim_t lda, int8_t ao, const int8_t *B, dnnl_dim_t ldb, int8_t bo,
10543 float beta, int32_t *C, dnnl_dim_t ldc, const int32_t *co,
10544 dnnl::threadpool_iface *tp) {
10545 return static_cast<status>(dnnl_gemm_s8s8s32_tp(transa, transb, offsetc, M,
10546 N, K, alpha, A, lda, ao, B, ldb, bo, beta, C, ldc, co, tp));
10547}
10548#endif
10549
10551
10552// implementation section
10553
10556 dnnl_primitive_t result;
10558 "could not create a primitive");
10559 reset(result);
10560}
10561
10562inline primitive::primitive(const primitive_desc &pd) : primitive(pd.get()) {}
10563
10564inline void primitive::execute(const stream &stream,
10565 const std::unordered_map<int, memory> &args) const {
10566 std::vector<dnnl_exec_arg_t> c_args;
10567 c_args.reserve(args.size());
10568 for (const auto &a : args)
10569 c_args.push_back({a.first, a.second.get(true)});
10570
10571 error::wrap_c_api(dnnl_primitive_execute(get(), stream.get(),
10572 (int)c_args.size(), c_args.data()),
10573 "could not execute a primitive");
10574}
10576
10577#undef DNNL_DEFINE_BITMASK_OPS
10578
10579} // namespace dnnl
10580
10582
10583#endif
C API.
algorithm
Kinds of algorithms.
Definition dnnl.hpp:475
dnnl_status_t DNNL_API dnnl_primitive_attr_set_rnn_data_qparams(dnnl_primitive_attr_t attr, const float scale, const float shift)
Set quantization scale and shift parameters for RNN data tensors.
struct dnnl_primitive_attr * dnnl_primitive_attr_t
A primitive descriptor attributes handle that controls primitive behavior.
Definition dnnl_types.h:1734
dnnl_status_t DNNL_API dnnl_post_ops_get_params_dw_k3s2p1(const_dnnl_post_ops_t post_ops, int index, dnnl_data_type_t *weights_data_type, dnnl_data_type_t *bias_data_type, dnnl_data_type_t *dst_data_type, dnnl_dim_t *count, int *mask, const float **scales)
Returns the parameters of a depthwise post-op with stride 2.
dnnl_status_t DNNL_API dnnl_primitive_attr_set_scratchpad_mode(dnnl_primitive_attr_t attr, dnnl_scratchpad_mode_t mode)
Sets primitive attributes scratchpad mode.
dnnl_status_t DNNL_API dnnl_primitive_attr_get_post_ops(const_dnnl_primitive_attr_t attr, const_dnnl_post_ops_t *post_ops)
Returns primitive attributes post-ops.
dnnl_status_t DNNL_API dnnl_post_ops_append_dw_k3s1p1(dnnl_post_ops_t post_ops, dnnl_data_type_t weights_data_type, dnnl_data_type_t bias_data_type, dnnl_data_type_t dst_data_type, dnnl_dim_t count, int mask, const float *scales)
Appends a depthwise post-op convolution with stride 1.
dnnl_status_t DNNL_API dnnl_post_ops_destroy(dnnl_post_ops_t post_ops)
Destroys post-ops.
dnnl_status_t DNNL_API dnnl_primitive_attr_set_zero_points(dnnl_primitive_attr_t attr, int arg, dnnl_dim_t count, int mask, const int32_t *zero_points)
Sets primitive attributes zero points for primitive operations for a given memory argument.
dnnl_status_t DNNL_API dnnl_primitive_attr_set_post_ops(dnnl_primitive_attr_t attr, const_dnnl_post_ops_t post_ops)
Sets primitive attributes post-ops.
dnnl_status_t DNNL_API dnnl_post_ops_append_sum(dnnl_post_ops_t post_ops, float scale)
Appends an accumulation (sum) to post-ops.
struct dnnl_post_ops * dnnl_post_ops_t
A post operation chain handle.
Definition dnnl_types.h:1760
dnnl_status_t DNNL_API dnnl_primitive_attr_set_rnn_weights_qparams(dnnl_primitive_attr_t attr, dnnl_dim_t count, int mask, const float *scales)
Sets quantization scaling factors for RNN weights tensors.
const struct dnnl_primitive_attr * const_dnnl_primitive_attr_t
A constant primitive descriptor attributes handle.
Definition dnnl_types.h:1737
dnnl_status_t DNNL_API dnnl_primitive_attr_destroy(dnnl_primitive_attr_t attr)
Destroys primitive attributes.
int DNNL_API dnnl_post_ops_len(const_dnnl_post_ops_t post_ops)
Returns the length of post-ops.
const struct dnnl_post_ops * const_dnnl_post_ops_t
A constant post operation chain handle.
Definition dnnl_types.h:1763
dnnl_status_t DNNL_API dnnl_post_ops_append_dw_k3s2p1(dnnl_post_ops_t post_ops, dnnl_data_type_t weights_data_type, dnnl_data_type_t bias_data_type, dnnl_data_type_t dst_data_type, dnnl_dim_t count, int mask, const float *scales)
Appends a depthwise post-op convolution with stride 2.
dnnl_status_t DNNL_API dnnl_post_ops_get_params_eltwise(const_dnnl_post_ops_t post_ops, int index, float *scale, dnnl_alg_kind_t *alg_kind, float *alpha, float *beta)
Returns the parameters of an elementwise post-op.
dnnl_status_t DNNL_API dnnl_post_ops_create(dnnl_post_ops_t *post_ops)
Creates empty post-ops sequence.
dnnl_status_t DNNL_API dnnl_primitive_attr_set_scales(dnnl_primitive_attr_t attr, int arg, dnnl_dim_t count, int mask, const float *scales)
Sets primitive attributes scaling factors for primitive operations for a given memory argument.
dnnl_status_t DNNL_API dnnl_primitive_attr_get_scratchpad_mode(const_dnnl_primitive_attr_t attr, dnnl_scratchpad_mode_t *mode)
Returns the primitive attributes scratchpad mode.
dnnl_status_t DNNL_API dnnl_primitive_attr_clone(dnnl_primitive_attr_t *attr, const_dnnl_primitive_attr_t existing_attr)
Clones primitive attributes.
dnnl_status_t DNNL_API dnnl_post_ops_get_params_dw_k3s1p1(const_dnnl_post_ops_t post_ops, int index, dnnl_data_type_t *weights_data_type, dnnl_data_type_t *bias_data_type, dnnl_data_type_t *dst_data_type, dnnl_dim_t *count, int *mask, const float **scales)
Returns the parameters of a depthwise post-op with stride 1.
dnnl_primitive_kind_t DNNL_API dnnl_post_ops_get_kind(const_dnnl_post_ops_t post_ops, int index)
Returns the kind of a post-op entry.
scratchpad_mode
Scratchpad mode.
Definition dnnl.hpp:406
prop_kind
Propagation kind.
Definition dnnl.hpp:440
dnnl_scratchpad_mode_t
Scratchpad mode.
Definition dnnl_types.h:1700
dnnl_status_t DNNL_API dnnl_post_ops_append_eltwise(dnnl_post_ops_t post_ops, float scale, dnnl_alg_kind_t alg_kind, float alpha, float beta)
Appends an elementwise post-op.
dnnl_status_t DNNL_API dnnl_primitive_attr_get_zero_points(const_dnnl_primitive_attr_t attr, int arg, dnnl_dim_t *count, int *mask, const int32_t **zero_points)
Returns count, correspondence zero point mask, and a pointer to a constant int32_t array of zero_poin...
dnnl_status_t DNNL_API dnnl_post_ops_get_params_sum(const_dnnl_post_ops_t post_ops, int index, float *scale)
Returns the parameters of an accumulation (sum) post-op.
dnnl_status_t DNNL_API dnnl_primitive_attr_set_output_scales(dnnl_primitive_attr_t attr, dnnl_dim_t count, int mask, const float *scales)
Sets output scaling factors correspondence mask and values.
dnnl_status_t DNNL_API dnnl_primitive_attr_get_scales(dnnl_primitive_attr_t attr, int arg, dnnl_dim_t *count, int *mask, const float **scales)
Returns primitive attributes scaling factors correspondence mask and values for a given memory argume...
dnnl_status_t DNNL_API dnnl_primitive_attr_create(dnnl_primitive_attr_t *attr)
Creates an empty (default) primitive attributes object with all parameters set to their default values.
dnnl_status_t DNNL_API dnnl_primitive_attr_get_output_scales(const_dnnl_primitive_attr_t attr, dnnl_dim_t *count, int *mask, const float **scales)
Returns primitive attributes output scaling factors correspondence mask and values.
@ resampling_linear
Linear (Bilinear, Trilinear) resampling method.
Definition dnnl.hpp:574
@ binary_mul
Binary mul.
Definition dnnl.hpp:566
@ resampling_nearest
Nearest Neighbor resampling method.
Definition dnnl.hpp:572
@ eltwise_elu_use_dst_for_bwd
Elementwise: exponential linear unit (ELU) (dst for backward)
Definition dnnl.hpp:531
@ eltwise_tanh_use_dst_for_bwd
Elementwise: hyperbolic tangent non-linearity (tanh) (dst for backward)
Definition dnnl.hpp:529
@ eltwise_linear
Elementwise: linear.
Definition dnnl.hpp:504
@ eltwise_soft_relu
Elementwise: soft_relu.
Definition dnnl.hpp:508
@ vanilla_gru
GRU cell.
Definition dnnl.hpp:556
@ eltwise_logistic
Elementwise: logistic.
Definition dnnl.hpp:510
@ eltwise_clip
Elementwise: clip.
Definition dnnl.hpp:523
@ eltwise_abs
Elementwise: abs.
Definition dnnl.hpp:498
@ eltwise_pow
Elementwise: pow.
Definition dnnl.hpp:525
@ eltwise_tanh
Elementwise: hyperbolic tangent non-linearity (tanh)
Definition dnnl.hpp:492
@ eltwise_logistic_use_dst_for_bwd
Elementwise: logistic (dst for backward)
Definition dnnl.hpp:535
@ eltwise_bounded_relu
Elementwise: bounded_relu.
Definition dnnl.hpp:506
@ eltwise_square
Elementwise: square.
Definition dnnl.hpp:496
@ binary_max
Binary max.
Definition dnnl.hpp:568
@ convolution_direct
Direct convolution.
Definition dnnl.hpp:482
@ eltwise_exp
Elementwise: exponent.
Definition dnnl.hpp:512
@ eltwise_elu
Elementwise: exponential linear unit (ELU)
Definition dnnl.hpp:494
@ convolution_winograd
Winograd convolution.
Definition dnnl.hpp:484
@ vanilla_lstm
LSTM cell.
Definition dnnl.hpp:554
@ deconvolution_direct
Direct deconvolution.
Definition dnnl.hpp:486
@ pooling_avg
Average pooling exclude padding, alias for dnnl::algorithm::pooling_avg_exclude_padding.
Definition dnnl.hpp:546
@ lbr_gru
GRU cell with linear before reset.
Definition dnnl.hpp:562
@ pooling_avg_exclude_padding
Average pooling exclude padding.
Definition dnnl.hpp:550
@ eltwise_gelu
Elementwise: gelu alias for dnnl::algorithm::eltwise_gelu_tanh.
Definition dnnl.hpp:515
@ eltwise_sqrt
Elementwise: square root.
Definition dnnl.hpp:500
@ pooling_max
Max pooling.
Definition dnnl.hpp:543
@ eltwise_gelu_erf
Elementwise: erf-based gelu.
Definition dnnl.hpp:519
@ eltwise_swish
Elementwise: swish ($x \cdot \mathrm{sigmoid}(\alpha x)$)
Definition dnnl.hpp:502
@ lrn_within_channel
LRN within a single channel.
Definition dnnl.hpp:541
@ vanilla_rnn
RNN cell.
Definition dnnl.hpp:552
@ binary_add
Binary add.
Definition dnnl.hpp:564
@ lrn_across_channels
Local response normalization (LRN) across multiple channels.
Definition dnnl.hpp:539
@ eltwise_relu
Elementwise: rectified linear unit (ReLU)
Definition dnnl.hpp:490
@ eltwise_gelu_tanh
Elementwise: tanh-based gelu.
Definition dnnl.hpp:517
@ eltwise_relu_use_dst_for_bwd
Elementwise: rectified linear unit (ReLU) (dst for backward)
Definition dnnl.hpp:527
@ convolution_auto
Convolution algorithm that is chosen to be either direct or Winograd automatically.
Definition dnnl.hpp:480
@ binary_min
Binary min.
Definition dnnl.hpp:570
@ eltwise_exp_use_dst_for_bwd
Elementwise: exponent (dst for backward)
Definition dnnl.hpp:537
@ eltwise_sqrt_use_dst_for_bwd
Elementwise: square root (dst for backward)
Definition dnnl.hpp:533
@ pooling_avg_include_padding
Average pooling include padding.
Definition dnnl.hpp:548
@ deconvolution_winograd
Winograd deconvolution.
Definition dnnl.hpp:488
@ eltwise_log
Elementwise: natural logarithm.
Definition dnnl.hpp:521
@ library
The library manages the scratchpad allocation according to the policy specified by the DNNL_ENABLE_CO...
Definition dnnl.hpp:423
@ user
The user manages the scratchpad allocation by querying and providing the scratchpad memory to primiti...
Definition dnnl.hpp:428
@ backward
Backward propagation (with respect to all parameters).
Definition dnnl.hpp:457
@ backward_weights
Backward weights propagation.
Definition dnnl.hpp:461
@ forward_training
Forward data propagation (training mode).
Definition dnnl.hpp:445
@ forward_inference
Forward data propagation (inference mode).
Definition dnnl.hpp:449
@ forward_scoring
Forward data propagation, alias for dnnl::prop_kind::forward_inference.
Definition dnnl.hpp:452
@ forward
Forward data propagation, alias for dnnl::prop_kind::forward_training.
Definition dnnl.hpp:455
@ backward_data
Backward data propagation.
Definition dnnl.hpp:459
@ backward_bias
Backward bias propagation.
Definition dnnl.hpp:463
@ undef
Undefined propagation kind.
Definition dnnl.hpp:442
@ dnnl_scratchpad_mode_user
The user manages the scratchpad allocation by querying and providing the scratchpad memory to primiti...
Definition dnnl_types.h:1722
@ dnnl_scratchpad_mode_library
The library manages the scratchpad allocation according to the policy specified by the DNNL_ENABLE_CO...
Definition dnnl_types.h:1717
dnnl_status_t DNNL_API dnnl_batch_normalization_backward_desc_init(dnnl_batch_normalization_desc_t *bnrm_desc, dnnl_prop_kind_t prop_kind, const dnnl_memory_desc_t *diff_data_desc, const dnnl_memory_desc_t *data_desc, float epsilon, unsigned flags)
Initializes a descriptor for a batch normalization backward propagation primitive.
dnnl_status_t DNNL_API dnnl_batch_normalization_forward_desc_init(dnnl_batch_normalization_desc_t *bnrm_desc, dnnl_prop_kind_t prop_kind, const dnnl_memory_desc_t *data_desc, float epsilon, unsigned flags)
Initializes a descriptor for a batch normalization forward propagation primitive.
dnnl_status_t DNNL_API dnnl_binary_desc_init(dnnl_binary_desc_t *binary_desc, dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *src0_desc, const dnnl_memory_desc_t *src1_desc, const dnnl_memory_desc_t *dst_desc)
Initializes a descriptor for a binary primitive.
dnnl_status_t DNNL_API dnnl_gemm_s8s8s32(char transa, char transb, char offsetc, dnnl_dim_t M, dnnl_dim_t N, dnnl_dim_t K, float alpha, const int8_t *A, dnnl_dim_t lda, int8_t ao, const int8_t *B, dnnl_dim_t ldb, int8_t bo, float beta, int32_t *C, dnnl_dim_t ldc, const int32_t *co)
Performs integer matrix-matrix multiply on 8-bit signed matrix A, 8-bit signed matrix B,...
status gemm_u8s8s32(char transa, char transb, char offsetc, dnnl_dim_t M, dnnl_dim_t N, dnnl_dim_t K, float alpha, const uint8_t *A, dnnl_dim_t lda, uint8_t ao, const int8_t *B, dnnl_dim_t ldb, int8_t bo, float beta, int32_t *C, dnnl_dim_t ldc, const int32_t *co)
Performs integer matrix-matrix multiply on 8-bit unsigned matrix A, 8-bit signed matrix B,...
Definition dnnl.hpp:10503
status gemm_s8s8s32(char transa, char transb, char offsetc, dnnl_dim_t M, dnnl_dim_t N, dnnl_dim_t K, float alpha, const int8_t *A, dnnl_dim_t lda, int8_t ao, const int8_t *B, dnnl_dim_t ldb, int8_t bo, float beta, int32_t *C, dnnl_dim_t ldc, const int32_t *co)
Performs integer matrix-matrix multiply on 8-bit signed matrix A, 8-bit signed matrix B,...
Definition dnnl.hpp:10512
dnnl_status_t DNNL_API dnnl_sgemm(char transa, char transb, dnnl_dim_t M, dnnl_dim_t N, dnnl_dim_t K, float alpha, const float *A, dnnl_dim_t lda, const float *B, dnnl_dim_t ldb, float beta, float *C, dnnl_dim_t ldc)
Performs single-precision matrix-matrix multiply.
status sgemm(char transa, char transb, dnnl_dim_t M, dnnl_dim_t N, dnnl_dim_t K, float alpha, const float *A, dnnl_dim_t lda, const float *B, dnnl_dim_t ldb, float beta, float *C, dnnl_dim_t ldc)
Performs single-precision matrix-matrix multiply.
Definition dnnl.hpp:10495
dnnl_status_t DNNL_API dnnl_gemm_u8s8s32(char transa, char transb, char offsetc, dnnl_dim_t M, dnnl_dim_t N, dnnl_dim_t K, float alpha, const uint8_t *A, dnnl_dim_t lda, uint8_t ao, const int8_t *B, dnnl_dim_t ldb, int8_t bo, float beta, int32_t *C, dnnl_dim_t ldc, const int32_t *co)
Performs integer matrix-matrix multiply on 8-bit unsigned matrix A, 8-bit signed matrix B,...
dnnl_status_t DNNL_API dnnl_concat_primitive_desc_create(dnnl_primitive_desc_t *concat_primitive_desc, const dnnl_memory_desc_t *dst_desc, int n, int concat_dimension, const dnnl_memory_desc_t *src_descs, const_dnnl_primitive_attr_t attr, dnnl_engine_t engine)
Creates a primitive descriptor for an out-of-place concatenation primitive.
dnnl_status_t DNNL_API dnnl_convolution_forward_desc_init(dnnl_convolution_desc_t *conv_desc, dnnl_prop_kind_t prop_kind, dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *src_desc, const dnnl_memory_desc_t *weights_desc, const dnnl_memory_desc_t *bias_desc, const dnnl_memory_desc_t *dst_desc, const dnnl_dims_t strides, const dnnl_dims_t padding_l, const dnnl_dims_t padding_r)
Initializes a descriptor for a convolution forward propagation primitive.
dnnl_status_t DNNL_API dnnl_dilated_convolution_backward_weights_desc_init(dnnl_convolution_desc_t *conv_desc, dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *src_desc, const dnnl_memory_desc_t *diff_weights_desc, const dnnl_memory_desc_t *diff_bias_desc, const dnnl_memory_desc_t *diff_dst_desc, const dnnl_dims_t strides, const dnnl_dims_t dilates, const dnnl_dims_t padding_l, const dnnl_dims_t padding_r)
Initializes a descriptor for a dilated convolution weights gradient primitive.
dnnl_status_t DNNL_API dnnl_dilated_convolution_forward_desc_init(dnnl_convolution_desc_t *conv_desc, dnnl_prop_kind_t prop_kind, dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *src_desc, const dnnl_memory_desc_t *weights_desc, const dnnl_memory_desc_t *bias_desc, const dnnl_memory_desc_t *dst_desc, const dnnl_dims_t strides, const dnnl_dims_t dilates, const dnnl_dims_t padding_l, const dnnl_dims_t padding_r)
Initializes a descriptor for a dilated convolution forward propagation primitive.
dnnl_status_t DNNL_API dnnl_convolution_backward_weights_desc_init(dnnl_convolution_desc_t *conv_desc, dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *src_desc, const dnnl_memory_desc_t *diff_weights_desc, const dnnl_memory_desc_t *diff_bias_desc, const dnnl_memory_desc_t *diff_dst_desc, const dnnl_dims_t strides, const dnnl_dims_t padding_l, const dnnl_dims_t padding_r)
Initializes a descriptor for a convolution weights gradient primitive.
dnnl_status_t DNNL_API dnnl_dilated_convolution_backward_data_desc_init(dnnl_convolution_desc_t *conv_desc, dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *diff_src_desc, const dnnl_memory_desc_t *weights_desc, const dnnl_memory_desc_t *diff_dst_desc, const dnnl_dims_t strides, const dnnl_dims_t dilates, const dnnl_dims_t padding_l, const dnnl_dims_t padding_r)
Initializes a descriptor for a dilated convolution backward propagation primitive.
dnnl_status_t DNNL_API dnnl_convolution_backward_data_desc_init(dnnl_convolution_desc_t *conv_desc, dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *diff_src_desc, const dnnl_memory_desc_t *weights_desc, const dnnl_memory_desc_t *diff_dst_desc, const dnnl_dims_t strides, const dnnl_dims_t padding_l, const dnnl_dims_t padding_r)
Initializes a descriptor for a convolution backward propagation primitive.
dnnl_status_t DNNL_API dnnl_dilated_deconvolution_backward_data_desc_init(dnnl_deconvolution_desc_t *deconv_desc, dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *diff_src_desc, const dnnl_memory_desc_t *weights_desc, const dnnl_memory_desc_t *diff_dst_desc, const dnnl_dims_t strides, const dnnl_dims_t dilates, const dnnl_dims_t padding_l, const dnnl_dims_t padding_r)
Initializes a descriptor for a dilated deconvolution backward propagation primitive.
dnnl_status_t DNNL_API dnnl_deconvolution_forward_desc_init(dnnl_deconvolution_desc_t *deconv_desc, dnnl_prop_kind_t prop_kind, dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *src_desc, const dnnl_memory_desc_t *weights_desc, const dnnl_memory_desc_t *bias_desc, const dnnl_memory_desc_t *dst_desc, const dnnl_dims_t strides, const dnnl_dims_t padding_l, const dnnl_dims_t padding_r)
Initializes a descriptor for a deconvolution forward propagation primitive.
dnnl_status_t DNNL_API dnnl_deconvolution_backward_weights_desc_init(dnnl_deconvolution_desc_t *deconv_desc, dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *src_desc, const dnnl_memory_desc_t *diff_weights_desc, const dnnl_memory_desc_t *diff_bias_desc, const dnnl_memory_desc_t *diff_dst_desc, const dnnl_dims_t strides, const dnnl_dims_t padding_l, const dnnl_dims_t padding_r)
Initializes a descriptor for a deconvolution weights gradient primitive.
dnnl_status_t DNNL_API dnnl_deconvolution_backward_data_desc_init(dnnl_deconvolution_desc_t *deconv_desc, dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *diff_src_desc, const dnnl_memory_desc_t *weights_desc, const dnnl_memory_desc_t *diff_dst_desc, const dnnl_dims_t strides, const dnnl_dims_t padding_l, const dnnl_dims_t padding_r)
Initializes a descriptor for a deconvolution backward propagation primitive.
dnnl_status_t DNNL_API dnnl_dilated_deconvolution_forward_desc_init(dnnl_deconvolution_desc_t *deconv_desc, dnnl_prop_kind_t prop_kind, dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *src_desc, const dnnl_memory_desc_t *weights_desc, const dnnl_memory_desc_t *bias_desc, const dnnl_memory_desc_t *dst_desc, const dnnl_dims_t strides, const dnnl_dims_t dilates, const dnnl_dims_t padding_l, const dnnl_dims_t padding_r)
Initializes a descriptor for a dilated deconvolution forward propagation primitive.
dnnl_status_t DNNL_API dnnl_dilated_deconvolution_backward_weights_desc_init(dnnl_deconvolution_desc_t *deconv_desc, dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *src_desc, const dnnl_memory_desc_t *diff_weights_desc, const dnnl_memory_desc_t *diff_bias_desc, const dnnl_memory_desc_t *diff_dst_desc, const dnnl_dims_t strides, const dnnl_dims_t dilates, const dnnl_dims_t padding_l, const dnnl_dims_t padding_r)
Initializes a descriptor for a dilated deconvolution weights gradient primitive.
dnnl_convolution_desc_t dnnl_deconvolution_desc_t
A descriptor of a deconvolution operation.
Definition dnnl_types.h:1179
dnnl_status_t DNNL_API dnnl_eltwise_forward_desc_init(dnnl_eltwise_desc_t *eltwise_desc, dnnl_prop_kind_t prop_kind, dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *data_desc, float alpha, float beta)
Initializes a descriptor for an eltwise forward propagation primitive.
dnnl_status_t DNNL_API dnnl_eltwise_backward_desc_init(dnnl_eltwise_desc_t *eltwise_desc, dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *diff_data_desc, const dnnl_memory_desc_t *data_desc, float alpha, float beta)
Initializes a descriptor for an eltwise backward propagation primitive.
dnnl_engine_kind_t
Kinds of engines.
Definition dnnl_types.h:1646
dnnl_status_t DNNL_API dnnl_engine_get_ocl_device(dnnl_engine_t engine, cl_device_id *device)
Returns the OpenCL device associated with an engine.
struct dnnl_engine * dnnl_engine_t
An engine handle.
Definition dnnl_types.h:1659
dnnl_status_t DNNL_API dnnl_engine_create_ocl(dnnl_engine_t *engine, dnnl_engine_kind_t kind, cl_device_id device, cl_context context)
Creates an engine associated with an OpenCL device and an OpenCL context.
dnnl_status_t DNNL_API dnnl_engine_get_kind(dnnl_engine_t engine, dnnl_engine_kind_t *kind)
Returns the kind of an engine.
dnnl_status_t DNNL_API dnnl_engine_destroy(dnnl_engine_t engine)
Destroys an engine.
dnnl_status_t DNNL_API dnnl_engine_get_ocl_context(dnnl_engine_t engine, cl_context *context)
Returns the OpenCL context associated with an engine.
dnnl_status_t DNNL_API dnnl_engine_create(dnnl_engine_t *engine, dnnl_engine_kind_t kind, size_t index)
Creates an engine.
size_t DNNL_API dnnl_engine_get_count(dnnl_engine_kind_t kind)
Returns the number of engines of a particular kind.
@ dnnl_gpu
GPU engine.
Definition dnnl_types.h:1652
@ dnnl_cpu
CPU engine.
Definition dnnl_types.h:1650
@ dnnl_any_engine
An unspecified engine.
Definition dnnl_types.h:1648
dnnl_status_t DNNL_API dnnl_inner_product_forward_desc_init(dnnl_inner_product_desc_t *ip_desc, dnnl_prop_kind_t prop_kind, const dnnl_memory_desc_t *src_desc, const dnnl_memory_desc_t *weights_desc, const dnnl_memory_desc_t *bias_desc, const dnnl_memory_desc_t *dst_desc)
Initializes descriptor for inner product forward propagation.
dnnl_status_t DNNL_API dnnl_inner_product_backward_weights_desc_init(dnnl_inner_product_desc_t *ip_desc, const dnnl_memory_desc_t *src_desc, const dnnl_memory_desc_t *diff_weights_desc, const dnnl_memory_desc_t *diff_bias_desc, const dnnl_memory_desc_t *diff_dst_desc)
Initializes descriptor for inner product weights gradient primitive.
dnnl_status_t DNNL_API dnnl_inner_product_backward_data_desc_init(dnnl_inner_product_desc_t *ip_desc, const dnnl_memory_desc_t *diff_src_desc, const dnnl_memory_desc_t *weights_desc, const dnnl_memory_desc_t *diff_dst_desc)
Initializes descriptor for inner product backward propagation.
dnnl_status_t DNNL_API dnnl_layer_normalization_backward_desc_init(dnnl_layer_normalization_desc_t *lnrm_desc, dnnl_prop_kind_t prop_kind, const dnnl_memory_desc_t *diff_data_desc, const dnnl_memory_desc_t *data_desc, const dnnl_memory_desc_t *stat_desc, float epsilon, unsigned flags)
Initializes a descriptor for a layer normalization backward propagation primitive.
dnnl_status_t DNNL_API dnnl_layer_normalization_forward_desc_init(dnnl_layer_normalization_desc_t *lnrm_desc, dnnl_prop_kind_t prop_kind, const dnnl_memory_desc_t *data_desc, const dnnl_memory_desc_t *stat_desc, float epsilon, unsigned flags)
Initializes a descriptor for layer normalization forward propagation primitive.
dnnl_softmax_desc_t dnnl_logsoftmax_desc_t
A descriptor of a LogSoftmax operation.
Definition dnnl_types.h:1283
dnnl_status_t DNNL_API dnnl_logsoftmax_forward_desc_init(dnnl_logsoftmax_desc_t *logsoftmax_desc, dnnl_prop_kind_t prop_kind, const dnnl_memory_desc_t *data_desc, int logsoftmax_axis)
Initializes a descriptor for logsoftmax forward propagation primitive.
dnnl_status_t DNNL_API dnnl_logsoftmax_backward_desc_init(dnnl_logsoftmax_desc_t *logsoftmax_desc, const dnnl_memory_desc_t *diff_data_desc, const dnnl_memory_desc_t *data_desc, int logsoftmax_axis)
Initializes a descriptor for logsoftmax backward propagation primitive.
dnnl_status_t DNNL_API dnnl_lrn_backward_desc_init(dnnl_lrn_desc_t *lrn_desc, dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *diff_data_desc, const dnnl_memory_desc_t *data_desc, dnnl_dim_t local_size, float alpha, float beta, float k)
Initializes a descriptor for LRN backward propagation primitive.
dnnl_status_t DNNL_API dnnl_lrn_forward_desc_init(dnnl_lrn_desc_t *lrn_desc, dnnl_prop_kind_t prop_kind, dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *data_desc, dnnl_dim_t local_size, float alpha, float beta, float k)
Initializes a descriptor for LRN forward propagation primitive.
dnnl_status_t DNNL_API dnnl_matmul_desc_init(dnnl_matmul_desc_t *matmul_desc, const dnnl_memory_desc_t *src_desc, const dnnl_memory_desc_t *weights_desc, const dnnl_memory_desc_t *bias_desc, const dnnl_memory_desc_t *dst_desc)
Initializes a matrix multiplication descriptor.
dnnl_data_type_t
Data type specification.
Definition dnnl_types.h:62
dnnl_status_t DNNL_API dnnl_memory_desc_init_submemory(dnnl_memory_desc_t *memory_desc, const dnnl_memory_desc_t *parent_memory_desc, const dnnl_dims_t dims, const dnnl_dims_t offsets)
Initializes a memory descriptor for a region inside an area described by an existing memory descriptor.
struct dnnl_memory * dnnl_memory_t
A memory handle.
Definition dnnl_types.h:1104
dnnl_format_tag_t
Memory format tag specification.
Definition dnnl_types.h:164
dnnl_status_t DNNL_API dnnl_memory_desc_permute_axes(dnnl_memory_desc_t *out_memory_desc, const dnnl_memory_desc_t *in_memory_desc, const int *permutation)
Initializes a memory descriptor by permuting axes in an existing one.
dnnl_status_t DNNL_API dnnl_memory_unmap_data(const_dnnl_memory_t memory, void *mapped_ptr)
Unmaps a memory object and writes back any changes made to the previously mapped memory buffer.
dnnl_status_t DNNL_API dnnl_memory_create(dnnl_memory_t *memory, const dnnl_memory_desc_t *memory_desc, dnnl_engine_t engine, void *handle)
Creates a memory object.
dnnl_status_t DNNL_API dnnl_memory_get_engine(const_dnnl_memory_t memory, dnnl_engine_t *engine)
Returns the engine of a memory object.
dnnl_status_t DNNL_API dnnl_memory_desc_reshape(dnnl_memory_desc_t *out_memory_desc, const dnnl_memory_desc_t *in_memory_desc, int ndims, const dnnl_dims_t dims)
Initializes a memory descriptor by reshaping an existing one.
dnnl_status_t DNNL_API dnnl_memory_get_memory_desc(const_dnnl_memory_t memory, const dnnl_memory_desc_t **memory_desc)
Returns the memory descriptor for a memory object.
dnnl_status_t DNNL_API dnnl_memory_get_data_handle(const_dnnl_memory_t memory, void **handle)
Returns memory object's data handle.
dnnl_status_t DNNL_API dnnl_memory_set_data_handle_v2(dnnl_memory_t memory, void *handle, dnnl_stream_t stream)
Sets a memory object's data handle.
dnnl_status_t DNNL_API dnnl_memory_desc_init_by_strides(dnnl_memory_desc_t *memory_desc, int ndims, const dnnl_dims_t dims, dnnl_data_type_t data_type, const dnnl_dims_t strides)
Initializes a memory descriptor using dimensions and strides.
int64_t dnnl_dim_t
A type to describe tensor dimension.
Definition dnnl_types.h:944
dnnl_status_t DNNL_API dnnl_memory_destroy(dnnl_memory_t memory)
Destroys a memory object.
int DNNL_API dnnl_memory_desc_equal(const dnnl_memory_desc_t *lhs, const dnnl_memory_desc_t *rhs)
Compares two memory descriptors.
#define DNNL_MAX_NDIMS
Maximum number of dimensions a tensor can have.
Definition dnnl_types.h:912
dnnl_status_t DNNL_API dnnl_memory_map_data(const_dnnl_memory_t memory, void **mapped_ptr)
Maps a memory object and returns a host-side pointer to a memory buffer with a copy of its contents.
dnnl_status_t DNNL_API dnnl_memory_set_ocl_mem_object(dnnl_memory_t memory, cl_mem mem_object)
Sets OpenCL memory object associated with a memory object.
size_t DNNL_API dnnl_memory_desc_get_size(const dnnl_memory_desc_t *memory_desc)
Returns the size of a memory descriptor.
dnnl_status_t DNNL_API dnnl_memory_get_ocl_mem_object(const_dnnl_memory_t memory, cl_mem *mem_object)
Returns an OpenCL memory object associated with a memory object.
dnnl_status_t DNNL_API dnnl_memory_desc_init_by_tag(dnnl_memory_desc_t *memory_desc, int ndims, const dnnl_dims_t dims, dnnl_data_type_t data_type, dnnl_format_tag_t tag)
Initializes a memory descriptor using dimensions and memory format tag.
@ dnnl_f16
16-bit/half-precision floating point.
Definition dnnl_types.h:66
@ dnnl_bf16
Non-standard 16-bit floating point (bfloat16, with a 7-bit mantissa).
Definition dnnl_types.h:68
@ dnnl_f32
32-bit/single-precision floating point.
Definition dnnl_types.h:70
@ dnnl_data_type_undef
Undefined data type, used for empty memory descriptors.
Definition dnnl_types.h:64
@ dnnl_s8
8-bit signed integer.
Definition dnnl_types.h:74
@ dnnl_s32
32-bit signed integer.
Definition dnnl_types.h:72
@ dnnl_u8
8-bit unsigned integer.
Definition dnnl_types.h:76
@ dnnl_aBCdef2b4c2b
6D tensor blocked by 3rd dimension with block size 4
Definition dnnl_types.h:317
@ dnnl_acdeb
permuted 5D tensor
Definition dnnl_types.h:192
@ dnnl_ab
plain 2D tensor
Definition dnnl_types.h:178
@ dnnl_ABcd8b8a
4D tensor blocked by 1st and 2nd dimension with block size 8
Definition dnnl_types.h:258
@ dnnl_cdba
permuted 4D tensor
Definition dnnl_types.h:200
@ dnnl_aBcdef4b
6D tensor blocked by 2nd dimension with block size 4
Definition dnnl_types.h:319
@ dnnl_aBcd4b
4D tensor blocked by 2nd dimension with block size 4
Definition dnnl_types.h:238
@ dnnl_nCdhw16c
5D CNN activations tensor blocked by channels with block size 16, an alias to dnnl_aBcde16b
Definition dnnl_types.h:488
@ dnnl_abcde
plain 5D tensor
Definition dnnl_types.h:181
@ dnnl_decab
permuted 5D tensor
Definition dnnl_types.h:203
@ dnnl_bca
permuted 3D tensor
Definition dnnl_types.h:196
@ dnnl_aBcde4b
5D tensor blocked by 2nd dimension with block size 4
Definition dnnl_types.h:281
@ dnnl_aBc16b
3D tensor blocked by 2nd dimension with block size 16
Definition dnnl_types.h:212
@ dnnl_aBcdef16b
6D tensor blocked by 2nd dimension with block size 16
Definition dnnl_types.h:309
@ dnnl_aBCde2b4c2b
5D tensor blocked by 3rd dimension with block size 4
Definition dnnl_types.h:307
@ dnnl_aBc4b
3D tensor blocked by 2nd dimension with block size 4
Definition dnnl_types.h:216
@ dnnl_aBcd16b
4D tensor blocked by 2nd dimension with block size 16
Definition dnnl_types.h:232
@ dnnl_cba
permuted 3D tensor
Definition dnnl_types.h:199
@ dnnl_ba
permuted 2D tensor
Definition dnnl_types.h:193
@ dnnl_ABcde2b8a4b
5D tensor blocked by 1st dimension with block size 8
Definition dnnl_types.h:272
@ dnnl_abcd
plain 4D tensor
Definition dnnl_types.h:180
@ dnnl_format_tag_undef
Undefined memory format tag.
Definition dnnl_types.h:166
@ dnnl_nCdhw4c
5D CNN activations tensor blocked by channels with block size 4, an alias to dnnl_aBcde4b
Definition dnnl_types.h:491
@ dnnl_defcab
permuted 6D tensor
Definition dnnl_types.h:204
@ dnnl_abcdef
plain 6D tensor
Definition dnnl_types.h:182
@ dnnl_nChw8c
4D CNN activations tensor blocked by channels with block size 8, an alias to dnnl_aBcd8b
Definition dnnl_types.h:503
@ dnnl_a
plain 1D tensor
Definition dnnl_types.h:177
@ dnnl_nChw4c
4D CNN activations tensor blocked by channels with block size 4, an alias to dnnl_aBcd4b
Definition dnnl_types.h:500
@ dnnl_acbdef
permuted 6D tensor
Definition dnnl_types.h:190
@ dnnl_acdb
permuted 4D tensor
Definition dnnl_types.h:191
@ dnnl_aBcd8b
4D tensor blocked by 2nd dimension with block size 8
Definition dnnl_types.h:252
@ dnnl_aBc8b
3D tensor blocked by 2nd dimension with block size 8
Definition dnnl_types.h:223
@ dnnl_nCw4c
3D CNN activations tensor blocked by channels with block size 4, an alias to dnnl_aBc4b
Definition dnnl_types.h:509
@ dnnl_aBcde8b
5D tensor blocked by 2nd dimension with block size 8
Definition dnnl_types.h:293
@ dnnl_nChw16c
4D CNN activations tensor blocked by channels with block size 16, an alias to dnnl_aBcd16b
Definition dnnl_types.h:497
@ dnnl_abdec
permuted 5D tensor
Definition dnnl_types.h:187
@ dnnl_bacd
permuted 4D tensor
Definition dnnl_types.h:195
@ dnnl_nCdhw8c
5D CNN activations tensor blocked by channels with block size 8, an alias to dnnl_aBcde8b
Definition dnnl_types.h:494
@ dnnl_bcda
permuted 4D tensor
Definition dnnl_types.h:197
@ dnnl_acbde
permuted 5D tensor
Definition dnnl_types.h:189
@ dnnl_aBCd2b4c2b
4D tensor blocked by 3rd dimension with block size 4
Definition dnnl_types.h:268
@ dnnl_bcdea
permuted 5D tensor
Definition dnnl_types.h:198
@ dnnl_aBcde16b
5D tensor blocked by 2nd dimension with block size 16
Definition dnnl_types.h:274
@ dnnl_nCw8c
3D CNN activations tensor blocked by channels with block size 8, an alias to dnnl_aBc8b
Definition dnnl_types.h:512
@ dnnl_abdc
permuted 4D tensor
Definition dnnl_types.h:186
@ dnnl_ABcde4b16a4b
5D tensor blocked by 1st dimension with block size 16
Definition dnnl_types.h:270
@ dnnl_format_tag_last
Just a sentinel, not real memory format tag.
Definition dnnl_types.h:370
@ dnnl_abc
plain 3D tensor
Definition dnnl_types.h:179
@ dnnl_bac
permuted 3D tensor
Definition dnnl_types.h:194
@ dnnl_dcab
permuted 4D tensor
Definition dnnl_types.h:201
@ dnnl_cdeba
permuted 5D tensor
Definition dnnl_types.h:202
@ dnnl_acb
permuted 3D tensor
Definition dnnl_types.h:188
@ dnnl_nCw16c
3D CNN activations tensor blocked by channels with block size 16, an alias to dnnl_aBc16b
Definition dnnl_types.h:506
@ dnnl_format_tag_any
Unspecified memory format tag; the primitive selects the actual format automatically.
Definition dnnl_types.h:169
@ dnnl_blocked
A tensor in a generic format described by the stride and blocking values in each dimension.
Definition dnnl_types.h:89
@ dnnl_format_kind_wino
Weights format used in 8bit Winograd convolution.
Definition dnnl_types.h:91
@ dnnl_format_kind_any
Unspecified format kind.
Definition dnnl_types.h:85
@ dnnl_format_kind_undef
Undefined memory format kind, used for empty memory descriptors.
Definition dnnl_types.h:82
@ dnnl_format_kind_rnn_packed
Packed weights format used in RNN.
Definition dnnl_types.h:93
dnnl_status_t DNNL_API dnnl_pooling_forward_desc_init(dnnl_pooling_desc_t *pool_desc, dnnl_prop_kind_t prop_kind, dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *src_desc, const dnnl_memory_desc_t *dst_desc, const dnnl_dims_t strides, const dnnl_dims_t kernel, const dnnl_dims_t padding_l, const dnnl_dims_t padding_r)
Initializes a descriptor for pooling forward propagation primitive.
dnnl_status_t DNNL_API dnnl_pooling_backward_desc_init(dnnl_pooling_desc_t *pool_desc, dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *diff_src_desc, const dnnl_memory_desc_t *diff_dst_desc, const dnnl_dims_t strides, const dnnl_dims_t kernel, const dnnl_dims_t padding_l, const dnnl_dims_t padding_r)
Initializes a descriptor for pooling backward propagation primitive.
dnnl_status_t DNNL_API dnnl_primitive_desc_query(const_dnnl_primitive_desc_t primitive_desc, dnnl_query_t what, int index, void *result)
Queries a primitive descriptor for various pieces of information.
#define DNNL_ARG_DST_ITER
A special mnemonic for RNN output recurrent hidden state vector.
Definition dnnl_types.h:1817
dnnl_status_t DNNL_API dnnl_primitive_desc_iterator_destroy(dnnl_primitive_desc_iterator_t iterator)
Destroys a primitive descriptor iterator.
#define DNNL_ARG_WEIGHTS_LAYER
A special mnemonic for RNN weights applied to the layer input.
Definition dnnl_types.h:1835
#define DNNL_ARG_DIFF_BIAS
Gradient (diff) of the bias tensor argument.
Definition dnnl_types.h:1942
#define DNNL_ARG_DIFF_SRC_ITER_C
A special mnemonic for gradient (diff) of RNN input recurrent cell state vector.
Definition dnnl_types.h:1888
struct dnnl_primitive * dnnl_primitive_t
A primitive handle.
Definition dnnl_types.h:1774
#define DNNL_ARG_DIFF_SRC_LAYER
A special mnemonic for gradient (diff) of RNN input vector.
Definition dnnl_types.h:1876
#define DNNL_ARG_DIFF_WEIGHTS_PEEPHOLE
A special mnemonic for diff of RNN weights applied to the peephole weights.
Definition dnnl_types.h:1933
#define DNNL_ARG_WEIGHTS_PROJECTION
A special mnemonic for RNN weights applied to the projection weights.
Definition dnnl_types.h:1853
struct dnnl_primitive_desc_iterator * dnnl_primitive_desc_iterator_t
A primitive descriptor iterator handle.
Definition dnnl_types.h:1678
dnnl_normalization_flags_t
Flags for normalization primitives.
Definition dnnl_types.h:852
dnnl_primitive_kind_t convert_to_c(primitive::kind kind)
Converts primitive kind enum value from C++ API to C API type.
Definition dnnl.hpp:369
#define DNNL_ARG_DIFF_WEIGHTS_PROJECTION
A special mnemonic for diff of RNN weights applied to the projection weights.
Definition dnnl_types.h:1939
dnnl_status_t DNNL_API dnnl_primitive_desc_get_attr(const_dnnl_primitive_desc_t primitive_desc, const_dnnl_primitive_attr_t *attr)
Returns a constant reference to the attributes of a primitive descriptor.
#define DNNL_ARG_DIFF_WEIGHTS_ITER
A special mnemonic for diff of RNN weights applied to the recurrent input.
Definition dnnl_types.h:1927
#define DNNL_ARG_DIFF_SRC_ITER
A special mnemonic for gradient (diff) of RNN input recurrent hidden state vector.
Definition dnnl_types.h:1882
#define DNNL_ARG_DIFF_DST_ITER_C
A special mnemonic for gradient (diff) of RNN output recurrent cell state vector.
Definition dnnl_types.h:1909
dnnl_status_t DNNL_API dnnl_primitive_execute(const_dnnl_primitive_t primitive, dnnl_stream_t stream, int nargs, const dnnl_exec_arg_t *args)
Executes a primitive.
#define DNNL_ARG_WEIGHTS_ITER
A special mnemonic for RNN weights applied to the recurrent input.
Definition dnnl_types.h:1841
dnnl_status_t DNNL_API dnnl_primitive_desc_iterator_next(dnnl_primitive_desc_iterator_t iterator)
Advances the primitive descriptor iterator to point to the next available implementation.
dnnl_status_t DNNL_API dnnl_primitive_desc_destroy(dnnl_primitive_desc_t primitive_desc)
Destroys a primitive descriptor.
const void * const_dnnl_op_desc_t
A pointer to any of the operation descriptors (constant variant).
Definition dnnl_types.h:1122
const_dnnl_primitive_desc_t get_primitive_desc() const
Returns the C API primitive descriptor of the underlying C API primitive.
Definition dnnl.hpp:373
dnnl_status_t DNNL_API dnnl_primitive_get_primitive_desc(const_dnnl_primitive_t primitive, const_dnnl_primitive_desc_t *primitive_desc)
Retrieves a constant reference to the primitive descriptor of a given primitive.
#define DNNL_ARG_DST_ITER_C
A special mnemonic for LSTM output recurrent cell state vector.
Definition dnnl_types.h:1823
#define DNNL_ARG_SRC_ITER_C
A special mnemonic for RNN input recurrent cell state vector.
Definition dnnl_types.h:1800
query
Primitive descriptor query specification.
Definition dnnl.hpp:720
#define DNNL_ARG_FROM
A special mnemonic for reorder source argument.
Definition dnnl_types.h:1788
dnnl_alg_kind_t
Kinds of algorithms.
Definition dnnl_types.h:748
dnnl_primitive_kind_t
Kinds of primitives.
Definition dnnl_types.h:704
dnnl_query_t
Primitive descriptor query specification.
Definition dnnl_types.h:2002
const dnnl_memory_desc_t DNNL_API * dnnl_primitive_desc_query_md(const_dnnl_primitive_desc_t primitive_desc, dnnl_query_t what, int index)
Queries primitive descriptor for a memory descriptor.
struct dnnl_primitive_desc * dnnl_primitive_desc_t
A primitive descriptor handle.
Definition dnnl_types.h:1689
#define DNNL_ARG_WEIGHTS_PEEPHOLE
A special mnemonic for RNN weights applied to the peephole weights.
Definition dnnl_types.h:1847
const struct dnnl_primitive_desc * const_dnnl_primitive_desc_t
A constant primitive descriptor handle.
Definition dnnl_types.h:1692
kind get_kind() const
Returns the kind of the primitive.
Definition dnnl.hpp:380
#define DNNL_ARG_SRC_LAYER
A special mnemonic for RNN input vector.
Definition dnnl_types.h:1785
dnnl_status_t DNNL_API dnnl_primitive_destroy(dnnl_primitive_t primitive)
Destroys a primitive.
#define DNNL_ARG_DIFF_WEIGHTS_LAYER
A special mnemonic for diff of RNN weights applied to the layer input.
Definition dnnl_types.h:1921
dnnl_status_t DNNL_API dnnl_primitive_desc_iterator_create(dnnl_primitive_desc_iterator_t *iterator, const_dnnl_op_desc_t op_desc, const_dnnl_primitive_attr_t attr, dnnl_engine_t engine, const_dnnl_primitive_desc_t hint_forward_primitive_desc)
Creates a primitive descriptor iterator.
#define DNNL_ARG_DST_LAYER
A special mnemonic for RNN output vector. An alias for DNNL_ARG_DST_0.
Definition dnnl_types.h:1811
dnnl_status_t DNNL_API dnnl_primitive_create(dnnl_primitive_t *primitive, const_dnnl_primitive_desc_t primitive_desc)
Creates a primitive.
#define DNNL_ARG_BIAS
Bias tensor argument.
Definition dnnl_types.h:1856
normalization_flags
Flags for normalization primitives.
Definition dnnl.hpp:590
#define DNNL_ARG_DIFF_DST_ITER
A special mnemonic for gradient (diff) of RNN output recurrent hidden state vector.
Definition dnnl_types.h:1903
dnnl_prop_kind_t
Kinds of propagation.
Definition dnnl_types.h:677
dnnl_status_t DNNL_API dnnl_primitive_desc_clone(dnnl_primitive_desc_t *primitive_desc, const_dnnl_primitive_desc_t existing_primitive_desc)
Clones a primitive descriptor.
#define DNNL_ARG_SRC_ITER
A special mnemonic for RNN input recurrent hidden state vector.
Definition dnnl_types.h:1794
dnnl_primitive_desc_t DNNL_API dnnl_primitive_desc_iterator_fetch(const_dnnl_primitive_desc_iterator_t iterator)
Fetches the current primitive descriptor from a primitive descriptor iterator.
#define DNNL_ARG_TO
A special mnemonic for reorder destination argument.
Definition dnnl_types.h:1809
#define DNNL_ARG_DIFF_DST_LAYER
A special mnemonic for gradient (diff) of RNN output vector.
Definition dnnl_types.h:1897
@ dnnl_fuse_norm_relu
Fuse with ReLU.
Definition dnnl_types.h:900
@ dnnl_normalization_flags_none
Use no normalization flags.
Definition dnnl_types.h:861
@ dnnl_use_scaleshift
Use scale and shift parameters.
Definition dnnl_types.h:887
@ dnnl_use_global_stats
Use global statistics.
Definition dnnl_types.h:874
@ batch_normalization_d
batch normalization descriptor
Definition dnnl.hpp:776
@ weights_md
weights memory desc
Definition dnnl.hpp:797
@ memory_consumption_s64
memory consumption (bytes)
Definition dnnl.hpp:741
@ shuffle_d
shuffle descriptor
Definition dnnl.hpp:766
@ deconvolution_d
deconvolution descriptor
Definition dnnl.hpp:764
@ impl_info_str
implementation name
Definition dnnl.hpp:754
@ diff_weights_md
weights gradient (diff) memory desc
Definition dnnl.hpp:799
@ workspace_md
workspace memory desc
Definition dnnl.hpp:805
@ eltwise_d
eltwise descriptor
Definition dnnl.hpp:768
@ matmul_d
matmul descriptor
Definition dnnl.hpp:788
@ rnn_d
rnn descriptor
Definition dnnl.hpp:782
@ softmax_d
softmax descriptor
Definition dnnl.hpp:770
@ num_of_outputs_s32
number of outputs expected
Definition dnnl.hpp:732
@ primitive_kind
primitive kind
Definition dnnl.hpp:727
@ dst_md
destination memory desc
Definition dnnl.hpp:801
@ scratchpad_engine
scratchpad engine
Definition dnnl.hpp:746
@ reorder_src_engine
reorder source engine
Definition dnnl.hpp:749
@ op_d
operation descriptor
Definition dnnl.hpp:760
@ layer_normalization_d
layer normalization descriptor
Definition dnnl.hpp:778
@ logsoftmax_d
logsoftmax descriptor
Definition dnnl.hpp:786
@ pooling_d
pooling descriptor
Definition dnnl.hpp:772
@ num_of_inputs_s32
number of inputs expected
Definition dnnl.hpp:730
@ diff_src_md
source gradient (diff) memory desc
Definition dnnl.hpp:795
@ src_md
source memory desc
Definition dnnl.hpp:793
@ scratchpad_md
scratchpad memory desc
Definition dnnl.hpp:807
@ reorder_dst_engine
reorder destination engine
Definition dnnl.hpp:751
@ engine
execution engine
Definition dnnl.hpp:725
@ convolution_d
convolution descriptor
Definition dnnl.hpp:762
@ time_estimate_f64
runtime estimation (seconds), unimplemented
Definition dnnl.hpp:735
@ binary_d
binary descriptor
Definition dnnl.hpp:784
@ diff_dst_md
destination gradient (diff) memory desc
Definition dnnl.hpp:803
@ exec_arg_md
memory desc of an execute argument
Definition dnnl.hpp:809
@ inner_product_d
inner product descriptor
Definition dnnl.hpp:780
@ lrn_d
lrn descriptor
Definition dnnl.hpp:774
@ resampling_d
resampling descriptor
Definition dnnl.hpp:790
@ dnnl_pooling_avg_exclude_padding
Average pooling exclude padding.
Definition dnnl_types.h:816
@ dnnl_eltwise_clip
Eltwise: clip.
Definition dnnl_types.h:794
@ dnnl_eltwise_tanh_use_dst_for_bwd
Eltwise: hyperbolic tangent non-linearity (tanh) (dst for backward)
Definition dnnl_types.h:802
@ dnnl_pooling_avg
Average pooling (alias for dnnl_pooling_avg_exclude_padding)
Definition dnnl_types.h:818
@ dnnl_eltwise_gelu_tanh
Eltwise: tanh-based gelu.
Definition dnnl_types.h:786
@ dnnl_resampling_linear
Linear Resampling Method.
Definition dnnl_types.h:848
@ dnnl_eltwise_sqrt
Eltwise: square root.
Definition dnnl_types.h:771
@ dnnl_binary_min
Binary min.
Definition dnnl_types.h:844
@ dnnl_eltwise_abs
Eltwise: abs.
Definition dnnl_types.h:769
@ dnnl_eltwise_sqrt_use_dst_for_bwd
Eltwise: square root (dst for backward)
Definition dnnl_types.h:806
@ dnnl_eltwise_exp
Eltwise: exponent.
Definition dnnl_types.h:781
@ dnnl_eltwise_square
Eltwise: square.
Definition dnnl_types.h:767
@ dnnl_eltwise_gelu
Eltwise: tanh-based gelu (alias for dnnl_eltwise_gelu_tanh)
Definition dnnl_types.h:788
@ dnnl_convolution_winograd
Winograd convolution.
Definition dnnl_types.h:753
@ dnnl_lrn_across_channels
Local response normalization (LRN) across multiple channels.
Definition dnnl_types.h:820
@ dnnl_deconvolution_direct
Direct deconvolution.
Definition dnnl_types.h:757
@ dnnl_eltwise_relu
Eltwise: ReLU.
Definition dnnl_types.h:761
@ dnnl_convolution_auto
Convolution algorithm (either direct or Winograd) is chosen just in time.
Definition dnnl_types.h:755
@ dnnl_eltwise_swish
Eltwise: swish.
Definition dnnl_types.h:790
@ dnnl_vanilla_rnn
RNN cell.
Definition dnnl_types.h:824
@ dnnl_eltwise_gelu_erf
Eltwise: erf-based gelu.
Definition dnnl_types.h:798
@ dnnl_vanilla_lstm
LSTM cell.
Definition dnnl_types.h:826
@ dnnl_eltwise_elu
Eltwise: exponential linear unit (elu)
Definition dnnl_types.h:765
@ dnnl_vanilla_gru
GRU cell.
Definition dnnl_types.h:828
@ dnnl_lbr_gru
GRU cell with linear before reset.
Definition dnnl_types.h:836
@ dnnl_eltwise_tanh
Eltwise: hyperbolic tangent non-linearity (tanh)
Definition dnnl_types.h:763
@ dnnl_convolution_direct
Direct convolution.
Definition dnnl_types.h:751
@ dnnl_eltwise_soft_relu
Eltwise: soft_relu.
Definition dnnl_types.h:777
@ dnnl_eltwise_log
Eltwise: natural logarithm.
Definition dnnl_types.h:792
@ dnnl_lrn_within_channel
LRN within a single channel.
Definition dnnl_types.h:822
@ dnnl_eltwise_elu_use_dst_for_bwd
Eltwise: exponential linear unit (elu) (dst for backward)
Definition dnnl_types.h:804
@ dnnl_deconvolution_winograd
Winograd deconvolution.
Definition dnnl_types.h:759
@ dnnl_eltwise_pow
Eltwise: pow.
Definition dnnl_types.h:796
@ dnnl_eltwise_relu_use_dst_for_bwd
Eltwise: ReLU (dst for backward)
Definition dnnl_types.h:800
@ dnnl_eltwise_logistic
Eltwise: logistic.
Definition dnnl_types.h:779
@ dnnl_pooling_avg_include_padding
Average pooling include padding.
Definition dnnl_types.h:814
@ dnnl_pooling_max
Max pooling.
Definition dnnl_types.h:812
@ dnnl_eltwise_logistic_use_dst_for_bwd
Eltwise: logistic (dst for backward)
Definition dnnl_types.h:808
@ dnnl_binary_add
Binary add.
Definition dnnl_types.h:838
@ dnnl_binary_mul
Binary mul.
Definition dnnl_types.h:840
@ dnnl_eltwise_exp_use_dst_for_bwd
Eltwise: exp (dst for backward)
Definition dnnl_types.h:810
@ dnnl_eltwise_bounded_relu
Eltwise: bounded_relu.
Definition dnnl_types.h:775
@ dnnl_eltwise_linear
Eltwise: linear.
Definition dnnl_types.h:773
@ dnnl_resampling_nearest
Nearest Neighbor Resampling Method.
Definition dnnl_types.h:846
@ dnnl_binary_max
Binary max.
Definition dnnl_types.h:842
@ dnnl_binary
A binary primitive.
Definition dnnl_types.h:738
@ dnnl_concat
A (out-of-place) concat primitive.
Definition dnnl_types.h:712
@ dnnl_reorder
A reorder primitive.
Definition dnnl_types.h:708
@ dnnl_convolution
A convolution primitive.
Definition dnnl_types.h:716
@ dnnl_inner_product
An inner product primitive.
Definition dnnl_types.h:732
@ dnnl_resampling
A resampling primitive.
Definition dnnl_types.h:744
@ dnnl_batch_normalization
A batch normalization primitive.
Definition dnnl_types.h:728
@ dnnl_undefined_primitive
Undefined primitive.
Definition dnnl_types.h:706
@ dnnl_sum
A sum primitive.
Definition dnnl_types.h:714
@ dnnl_layer_normalization
A layer normalization primitive.
Definition dnnl_types.h:730
@ dnnl_eltwise
An element-wise primitive.
Definition dnnl_types.h:720
@ dnnl_matmul
A matrix multiplication primitive.
Definition dnnl_types.h:742
@ dnnl_shuffle
A shuffle primitive.
Definition dnnl_types.h:710
@ dnnl_logsoftmax
A logsoftmax primitive.
Definition dnnl_types.h:740
@ dnnl_pooling
A pooling primitive.
Definition dnnl_types.h:724
@ dnnl_deconvolution
A deconvolution primitive.
Definition dnnl_types.h:718
@ dnnl_softmax
A softmax primitive.
Definition dnnl_types.h:722
@ dnnl_rnn
An RNN primitive.
Definition dnnl_types.h:734
@ dnnl_lrn
An LRN primitive.
Definition dnnl_types.h:726
@ dnnl_query_resampling_d
resampling descriptor
Definition dnnl_types.h:2045
@ dnnl_query_num_of_outputs_s32
number of outputs expected
Definition dnnl_types.h:2009
@ dnnl_query_convolution_d
convolution descriptor
Definition dnnl_types.h:2030
@ dnnl_query_weights_md
weights memory desc
Definition dnnl_types.h:2051
@ dnnl_query_src_md
source memory desc
Definition dnnl_types.h:2049
@ dnnl_query_softmax_d
softmax descriptor
Definition dnnl_types.h:2034
@ dnnl_query_binary_d
binary descriptor
Definition dnnl_types.h:2042
@ dnnl_query_workspace_md
workspace memory desc
Definition dnnl_types.h:2055
@ dnnl_query_matmul_d
matrix multiplication (matmul) descriptor
Definition dnnl_types.h:2044
@ dnnl_query_num_of_inputs_s32
number of inputs expected
Definition dnnl_types.h:2008
@ dnnl_query_op_d
op descriptor
Definition dnnl_types.h:2029
@ dnnl_query_diff_src_md
source gradient memory desc
Definition dnnl_types.h:2050
@ dnnl_query_scratchpad_md
scratchpad memory desc
Definition dnnl_types.h:2056
@ dnnl_query_shuffle_d
shuffle descriptor
Definition dnnl_types.h:2032
@ dnnl_query_memory_consumption_s64
memory consumption – extra (scratch) memory, additional to all inputs and outputs memory (bytes)
Definition dnnl_types.h:2012
@ dnnl_query_inner_product_d
inner product descriptor
Definition dnnl_types.h:2039
@ dnnl_query_deconvolution_d
deconvolution descriptor
Definition dnnl_types.h:2031
@ dnnl_query_primitive_kind
primitive kind
Definition dnnl_types.h:2006
@ dnnl_query_batch_normalization_d
batch normalization descriptor
Definition dnnl_types.h:2037
@ dnnl_query_impl_info_str
implementation name
Definition dnnl_types.h:2020
@ dnnl_query_time_estimate_f64
runtime estimation (seconds)
Definition dnnl_types.h:2011
@ dnnl_query_eltwise_d
eltwise descriptor
Definition dnnl_types.h:2033
@ dnnl_query_diff_weights_md
weights grad. memory desc
Definition dnnl_types.h:2052
@ dnnl_query_reorder_dst_engine
destination engine
Definition dnnl_types.h:2023
@ dnnl_query_reorder_src_engine
source engine
Definition dnnl_types.h:2022
@ dnnl_query_scratchpad_engine
scratchpad engine – engine to be used for creating scratchpad memory
Definition dnnl_types.h:2017
@ dnnl_query_undef
no query
Definition dnnl_types.h:2003
@ dnnl_query_prop_kind
propagation kind
Definition dnnl_types.h:2025
@ dnnl_query_pooling_d
pooling descriptor
Definition dnnl_types.h:2035
@ dnnl_query_exec_arg_md
memory desc of an execute argument
Definition dnnl_types.h:2057
@ dnnl_query_engine
execution engine
Definition dnnl_types.h:2005
@ dnnl_query_rnn_d
rnn descriptor
Definition dnnl_types.h:2040
@ dnnl_query_layer_normalization_d
layer normalization descriptor
Definition dnnl_types.h:2038
@ dnnl_query_lrn_d
lrn descriptor
Definition dnnl_types.h:2036
@ dnnl_query_dst_md
destination memory desc
Definition dnnl_types.h:2053
@ dnnl_query_diff_dst_md
destination grad. memory desc
Definition dnnl_types.h:2054
@ dnnl_query_logsoftmax_d
logsoftmax descriptor
Definition dnnl_types.h:2043
@ use_scale_shift
Use scale and shift parameters.
Definition dnnl.hpp:611
@ none
Use no normalization flags.
Definition dnnl.hpp:595
@ fuse_norm_relu
Fuse normalization with ReLU.
Definition dnnl.hpp:617
@ use_global_stats
Use global statistics.
Definition dnnl.hpp:604
@ dnnl_backward_weights
Backward weights propagation.
Definition dnnl_types.h:697
@ dnnl_forward_inference
Forward data propagation (inference mode).
Definition dnnl_types.h:687
@ dnnl_backward
Backward propagation (with respect to all parameters).
Definition dnnl_types.h:693
@ dnnl_backward_data
Backward data propagation.
Definition dnnl_types.h:695
@ dnnl_prop_kind_undef
Undefined propagation type.
Definition dnnl_types.h:680
@ dnnl_forward
Forward data propagation (alias for dnnl_forward_training).
Definition dnnl_types.h:691
@ dnnl_forward_training
Forward data propagation (training mode).
Definition dnnl_types.h:683
@ dnnl_backward_bias
Backward bias propagation.
Definition dnnl_types.h:699
@ dnnl_forward_scoring
Forward data propagation (alias for dnnl_forward_inference).
Definition dnnl_types.h:689
dnnl_status_t DNNL_API dnnl_reorder_primitive_desc_create(dnnl_primitive_desc_t *reorder_primitive_desc, const dnnl_memory_desc_t *src_desc, dnnl_engine_t src_engine, const dnnl_memory_desc_t *dst_desc, dnnl_engine_t dst_engine, const_dnnl_primitive_attr_t attr)
Creates a primitive descriptor for a reorder primitive.
dnnl_status_t DNNL_API dnnl_resampling_backward_desc_init(dnnl_resampling_desc_t *resampling_desc, dnnl_alg_kind_t alg_kind, const float *factors, const dnnl_memory_desc_t *diff_src_desc, const dnnl_memory_desc_t *diff_dst_desc)
Initializes a descriptor for resampling backward propagation primitive.
dnnl_status_t DNNL_API dnnl_resampling_forward_desc_init(dnnl_resampling_desc_t *resampling_desc, dnnl_prop_kind_t prop_kind, dnnl_alg_kind_t alg_kind, const float *factors, const dnnl_memory_desc_t *src_desc, const dnnl_memory_desc_t *dst_desc)
Initializes a descriptor for a resampling forward propagation primitive.
dnnl_status_t DNNL_API dnnl_lbr_gru_backward_desc_init(dnnl_rnn_desc_t *rnn_desc, dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction, const dnnl_memory_desc_t *src_layer_desc, const dnnl_memory_desc_t *src_iter_desc, const dnnl_memory_desc_t *weights_layer_desc, const dnnl_memory_desc_t *weights_iter_desc, const dnnl_memory_desc_t *bias_desc, const dnnl_memory_desc_t *dst_layer_desc, const dnnl_memory_desc_t *dst_iter_desc, const dnnl_memory_desc_t *diff_src_layer_desc, const dnnl_memory_desc_t *diff_src_iter_desc, const dnnl_memory_desc_t *diff_weights_layer_desc, const dnnl_memory_desc_t *diff_weights_iter_desc, const dnnl_memory_desc_t *diff_bias_desc, const dnnl_memory_desc_t *diff_dst_layer_desc, const dnnl_memory_desc_t *diff_dst_iter_desc, unsigned flags)
Initializes a descriptor for LBR GRU backward propagation primitive.
dnnl_status_t DNNL_API dnnl_gru_forward_desc_init(dnnl_rnn_desc_t *rnn_desc, dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction, const dnnl_memory_desc_t *src_layer_desc, const dnnl_memory_desc_t *src_iter_desc, const dnnl_memory_desc_t *weights_layer_desc, const dnnl_memory_desc_t *weights_iter_desc, const dnnl_memory_desc_t *bias_desc, const dnnl_memory_desc_t *dst_layer_desc, const dnnl_memory_desc_t *dst_iter_desc, unsigned flags)
Initializes a descriptor for GRU forward propagation primitive.
rnn_direction
A direction of RNN primitive execution.
Definition dnnl.hpp:687
dnnl_rnn_flags_t
Flags for RNN cell.
Definition dnnl_types.h:1464
dnnl_status_t DNNL_API dnnl_vanilla_rnn_forward_desc_init(dnnl_rnn_desc_t *rnn_desc, dnnl_prop_kind_t prop_kind, const dnnl_alg_kind_t activation, const dnnl_rnn_direction_t direction, const dnnl_memory_desc_t *src_layer_desc, const dnnl_memory_desc_t *src_iter_desc, const dnnl_memory_desc_t *weights_layer_desc, const dnnl_memory_desc_t *weights_iter_desc, const dnnl_memory_desc_t *bias_desc, const dnnl_memory_desc_t *dst_layer_desc, const dnnl_memory_desc_t *dst_iter_desc, unsigned flags, float alpha, float beta)
Initializes a descriptor for vanilla RNN forward propagation primitive.
dnnl_status_t DNNL_API dnnl_lstm_backward_desc_init(dnnl_rnn_desc_t *rnn_desc, dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction, const dnnl_memory_desc_t *src_layer_desc, const dnnl_memory_desc_t *src_iter_desc, const dnnl_memory_desc_t *src_iter_c_desc, const dnnl_memory_desc_t *weights_layer_desc, const dnnl_memory_desc_t *weights_iter_desc, const dnnl_memory_desc_t *bias_desc, const dnnl_memory_desc_t *dst_layer_desc, const dnnl_memory_desc_t *dst_iter_desc, const dnnl_memory_desc_t *dst_iter_c_desc, const dnnl_memory_desc_t *diff_src_layer_desc, const dnnl_memory_desc_t *diff_src_iter_desc, const dnnl_memory_desc_t *diff_src_iter_c_desc, const dnnl_memory_desc_t *diff_weights_layer_desc, const dnnl_memory_desc_t *diff_weights_iter_desc, const dnnl_memory_desc_t *diff_bias_desc, const dnnl_memory_desc_t *diff_dst_layer_desc, const dnnl_memory_desc_t *diff_dst_iter_desc, const dnnl_memory_desc_t *diff_dst_iter_c_desc, unsigned flags)
Initializes a descriptor for an LSTM backward propagation primitive.
dnnl_rnn_direction_t
A direction of RNN primitive execution.
Definition dnnl_types.h:1470
dnnl_status_t DNNL_API dnnl_vanilla_rnn_backward_desc_init(dnnl_rnn_desc_t *rnn_desc, dnnl_prop_kind_t prop_kind, const dnnl_alg_kind_t activation, const dnnl_rnn_direction_t direction, const dnnl_memory_desc_t *src_layer_desc, const dnnl_memory_desc_t *src_iter_desc, const dnnl_memory_desc_t *weights_layer_desc, const dnnl_memory_desc_t *weights_iter_desc, const dnnl_memory_desc_t *bias_desc, const dnnl_memory_desc_t *dst_layer_desc, const dnnl_memory_desc_t *dst_iter_desc, const dnnl_memory_desc_t *diff_src_layer_desc, const dnnl_memory_desc_t *diff_src_iter_desc, const dnnl_memory_desc_t *diff_weights_layer_desc, const dnnl_memory_desc_t *diff_weights_iter_desc, const dnnl_memory_desc_t *diff_bias_desc, const dnnl_memory_desc_t *diff_dst_layer_desc, const dnnl_memory_desc_t *diff_dst_iter_desc, unsigned flags, float alpha, float beta)
Initializes a descriptor for vanilla RNN backward propagation primitive.
dnnl_status_t DNNL_API dnnl_lstm_backward_desc_init_v3(dnnl_rnn_desc_t *rnn_desc, dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction, const dnnl_memory_desc_t *src_layer_desc, const dnnl_memory_desc_t *src_iter_desc, const dnnl_memory_desc_t *src_iter_c_desc, const dnnl_memory_desc_t *weights_layer_desc, const dnnl_memory_desc_t *weights_iter_desc, const dnnl_memory_desc_t *weights_peephole_desc, const dnnl_memory_desc_t *weights_projection_desc, const dnnl_memory_desc_t *bias_desc, const dnnl_memory_desc_t *dst_layer_desc, const dnnl_memory_desc_t *dst_iter_desc, const dnnl_memory_desc_t *dst_iter_c_desc, const dnnl_memory_desc_t *diff_src_layer_desc, const dnnl_memory_desc_t *diff_src_iter_desc, const dnnl_memory_desc_t *diff_src_iter_c_desc, const dnnl_memory_desc_t *diff_weights_layer_desc, const dnnl_memory_desc_t *diff_weights_iter_desc, const dnnl_memory_desc_t *diff_weights_peephole_desc, const dnnl_memory_desc_t *diff_weights_projection_desc, const dnnl_memory_desc_t *diff_bias_desc, const dnnl_memory_desc_t *diff_dst_layer_desc, const dnnl_memory_desc_t *diff_dst_iter_desc, const dnnl_memory_desc_t *diff_dst_iter_c_desc, unsigned flags)
Initializes a descriptor for an LSTM (with or without peephole and with or without recurrent project...
dnnl_status_t DNNL_API dnnl_lstm_backward_desc_init_v2(dnnl_rnn_desc_t *rnn_desc, dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction, const dnnl_memory_desc_t *src_layer_desc, const dnnl_memory_desc_t *src_iter_desc, const dnnl_memory_desc_t *src_iter_c_desc, const dnnl_memory_desc_t *weights_layer_desc, const dnnl_memory_desc_t *weights_iter_desc, const dnnl_memory_desc_t *weights_peephole_desc, const dnnl_memory_desc_t *bias_desc, const dnnl_memory_desc_t *dst_layer_desc, const dnnl_memory_desc_t *dst_iter_desc, const dnnl_memory_desc_t *dst_iter_c_desc, const dnnl_memory_desc_t *diff_src_layer_desc, const dnnl_memory_desc_t *diff_src_iter_desc, const dnnl_memory_desc_t *diff_src_iter_c_desc, const dnnl_memory_desc_t *diff_weights_layer_desc, const dnnl_memory_desc_t *diff_weights_iter_desc, const dnnl_memory_desc_t *diff_weights_peephole_desc, const dnnl_memory_desc_t *diff_bias_desc, const dnnl_memory_desc_t *diff_dst_layer_desc, const dnnl_memory_desc_t *diff_dst_iter_desc, const dnnl_memory_desc_t *diff_dst_iter_c_desc, unsigned flags)
Initializes a descriptor for an LSTM (with or without peephole) backward propagation primitive.
dnnl_status_t DNNL_API dnnl_lstm_forward_desc_init(dnnl_rnn_desc_t *rnn_desc, dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction, const dnnl_memory_desc_t *src_layer_desc, const dnnl_memory_desc_t *src_iter_desc, const dnnl_memory_desc_t *src_iter_c_desc, const dnnl_memory_desc_t *weights_layer_desc, const dnnl_memory_desc_t *weights_iter_desc, const dnnl_memory_desc_t *bias_desc, const dnnl_memory_desc_t *dst_layer_desc, const dnnl_memory_desc_t *dst_iter_desc, const dnnl_memory_desc_t *dst_iter_c_desc, unsigned flags)
Initializes a descriptor for LSTM forward propagation primitive.
dnnl_status_t DNNL_API dnnl_lbr_gru_forward_desc_init(dnnl_rnn_desc_t *rnn_desc, dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction, const dnnl_memory_desc_t *src_layer_desc, const dnnl_memory_desc_t *src_iter_desc, const dnnl_memory_desc_t *weights_layer_desc, const dnnl_memory_desc_t *weights_iter_desc, const dnnl_memory_desc_t *bias_desc, const dnnl_memory_desc_t *dst_layer_desc, const dnnl_memory_desc_t *dst_iter_desc, unsigned flags)
Initializes a descriptor for LBR GRU forward propagation primitive.
dnnl_status_t DNNL_API dnnl_lstm_forward_desc_init_v3(dnnl_rnn_desc_t *rnn_desc, dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction, const dnnl_memory_desc_t *src_layer_desc, const dnnl_memory_desc_t *src_iter_desc, const dnnl_memory_desc_t *src_iter_c_desc, const dnnl_memory_desc_t *weights_layer_desc, const dnnl_memory_desc_t *weights_iter_desc, const dnnl_memory_desc_t *weights_peephole_desc, const dnnl_memory_desc_t *weights_projection_desc, const dnnl_memory_desc_t *bias_desc, const dnnl_memory_desc_t *dst_layer_desc, const dnnl_memory_desc_t *dst_iter_desc, const dnnl_memory_desc_t *dst_iter_c_desc, unsigned flags)
Initializes a descriptor for an LSTM (with or without peephole and with or without recurrent projecti...
rnn_flags
RNN cell flags.
Definition dnnl.hpp:633
dnnl_status_t DNNL_API dnnl_lstm_forward_desc_init_v2(dnnl_rnn_desc_t *rnn_desc, dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction, const dnnl_memory_desc_t *src_layer_desc, const dnnl_memory_desc_t *src_iter_desc, const dnnl_memory_desc_t *src_iter_c_desc, const dnnl_memory_desc_t *weights_layer_desc, const dnnl_memory_desc_t *weights_iter_desc, const dnnl_memory_desc_t *weights_peephole_desc, const dnnl_memory_desc_t *bias_desc, const dnnl_memory_desc_t *dst_layer_desc, const dnnl_memory_desc_t *dst_iter_desc, const dnnl_memory_desc_t *dst_iter_c_desc, unsigned flags)
Initializes a descriptor for an LSTM (with or without peephole) forward propagation primitive.
dnnl_status_t DNNL_API dnnl_gru_backward_desc_init(dnnl_rnn_desc_t *rnn_desc, dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction, const dnnl_memory_desc_t *src_layer_desc, const dnnl_memory_desc_t *src_iter_desc, const dnnl_memory_desc_t *weights_layer_desc, const dnnl_memory_desc_t *weights_iter_desc, const dnnl_memory_desc_t *bias_desc, const dnnl_memory_desc_t *dst_layer_desc, const dnnl_memory_desc_t *dst_iter_desc, const dnnl_memory_desc_t *diff_src_layer_desc, const dnnl_memory_desc_t *diff_src_iter_desc, const dnnl_memory_desc_t *diff_weights_layer_desc, const dnnl_memory_desc_t *diff_weights_iter_desc, const dnnl_memory_desc_t *diff_bias_desc, const dnnl_memory_desc_t *diff_dst_layer_desc, const dnnl_memory_desc_t *diff_dst_iter_desc, unsigned flags)
Initializes a descriptor for GRU backward propagation primitive.
@ unidirectional_left2right
Unidirectional execution of RNN primitive from left to right.
Definition dnnl.hpp:689
@ unidirectional_right2left
Unidirectional execution of RNN primitive from right to left.
Definition dnnl.hpp:691
@ bidirectional_concat
Bidirectional execution of RNN primitive with concatenation of the results.
Definition dnnl.hpp:694
@ unidirectional
Alias for dnnl::rnn_direction::unidirectional_left2right.
Definition dnnl.hpp:699
@ bidirectional_sum
Bidirectional execution of RNN primitive with summation of the results.
Definition dnnl.hpp:697
@ dnnl_rnn_flags_undef
Undefined RNN flags.
Definition dnnl_types.h:1466
@ dnnl_unidirectional
Alias for dnnl_unidirectional_left2right.
Definition dnnl_types.h:1482
@ dnnl_bidirectional_concat
Bidirectional execution of RNN primitive with concatenation of the results.
Definition dnnl_types.h:1477
@ dnnl_bidirectional_sum
Bidirectional execution of RNN primitive with summation of the results.
Definition dnnl_types.h:1480
@ dnnl_unidirectional_left2right
Unidirectional execution of RNN primitive from left to right.
Definition dnnl_types.h:1472
@ dnnl_unidirectional_right2left
Unidirectional execution of RNN primitive from right to left.
Definition dnnl_types.h:1474
@ undef
Undefined RNN flags.
Definition dnnl.hpp:635
dnnl_status_t DNNL_API dnnl_set_jit_dump(int enable)
Configures dumping of JIT-generated code.
status set_max_cpu_isa(cpu_isa isa)
Sets the maximal ISA the library can dispatch to on the CPU.
Definition dnnl.hpp:10480
dnnl_status_t DNNL_API dnnl_set_verbose(int level)
Configures verbose output to stdout.
const version_t * version()
Returns library version information.
Definition dnnl.hpp:10438
status set_jit_dump(int enable)
Configures dumping of JIT-generated code.
Definition dnnl.hpp:10443
dnnl_cpu_isa_t
CPU instruction set flags.
Definition dnnl_types.h:2150
status set_verbose(int level)
Configures verbose output to stdout.
Definition dnnl.hpp:10433
dnnl_status_t DNNL_API dnnl_set_max_cpu_isa(dnnl_cpu_isa_t isa)
Sets the maximal ISA the library can dispatch to on the CPU.
dnnl_status_t DNNL_API dnnl_set_jit_profiling_flags(unsigned flags)
Sets library profiling flags.
status set_jit_profiling_jitdumpdir(const std::string &dir)
Sets JIT dump output path.
Definition dnnl.hpp:10453
status
Status values returned by the library functions.
Definition dnnl.hpp:10415
status set_jit_profiling_flags(unsigned flags)
Sets library profiling flags.
Definition dnnl.hpp:10448
const dnnl_version_t DNNL_API * dnnl_version()
Returns library version information.
cpu_isa
CPU instruction set flags.
Definition dnnl.hpp:10458
dnnl_version_t version_t
Structure containing version information as per Semantic Versioning
Definition dnnl.hpp:10412
dnnl_status_t DNNL_API dnnl_set_jit_profiling_jitdumpdir(const char *dir)
Sets JIT dump output path.
@ dnnl_cpu_isa_avx512_mic
Intel Advanced Vector Extensions 512 (Intel AVX-512) subset for Intel Xeon Phi processors x200 Series...
Definition dnnl_types.h:2165
@ dnnl_cpu_isa_avx
Intel Advanced Vector Extensions (Intel AVX)
Definition dnnl_types.h:2158
@ dnnl_cpu_isa_avx512_core_vnni
Intel AVX-512 and Intel Deep Learning Boost (Intel DL Boost) support for Intel Xeon Scalable processo...
Definition dnnl_types.h:2178
@ dnnl_cpu_isa_avx2
Intel Advanced Vector Extensions 2 (Intel AVX2)
Definition dnnl_types.h:2161
@ dnnl_cpu_isa_all
Any ISA (no restrictions)
Definition dnnl_types.h:2152
@ dnnl_cpu_isa_avx512_core
Intel AVX-512 subset for Intel Xeon Scalable processor family and Intel Core processor family.
Definition dnnl_types.h:2173
@ dnnl_cpu_isa_sse41
Intel Streaming SIMD Extensions 4.1 (Intel SSE4.1)
Definition dnnl_types.h:2155
@ dnnl_cpu_isa_avx512_core_bf16
Intel AVX-512, Intel DL Boost and bfloat16 support for Intel Xeon Scalable processor family and Intel...
Definition dnnl_types.h:2183
@ dnnl_cpu_isa_avx512_mic_4ops
Intel AVX-512 subset for Intel Xeon Phi processors 7235, 7285, 7295 Series.
Definition dnnl_types.h:2169
@ not_required
Queried element is not required for given primitive.
Definition dnnl.hpp:10429
@ invalid_arguments
The operation failed because of incorrect function arguments.
Definition dnnl.hpp:10421
@ success
The operation was successful.
Definition dnnl.hpp:10417
@ unimplemented
The operation failed because requested functionality is not implemented.
Definition dnnl.hpp:10423
@ runtime_error
Primitive or engine failed on execution.
Definition dnnl.hpp:10427
@ out_of_memory
The operation failed due to an out-of-memory condition.
Definition dnnl.hpp:10419
@ iterator_ends
Primitive iterator passed over last primitive descriptor.
Definition dnnl.hpp:10425
@ avx512_mic
Intel Advanced Vector Extensions 512 (Intel AVX-512) subset for Intel Xeon Phi processors x200 Series...
Definition dnnl.hpp:10468
@ avx2
Intel Advanced Vector Extensions 2 (Intel AVX2)
Definition dnnl.hpp:10466
@ avx
Intel Advanced Vector Extensions (Intel AVX)
Definition dnnl.hpp:10464
@ all
Any ISA (no restrictions)
Definition dnnl.hpp:10460
@ avx512_core
Intel AVX-512 subset for Intel Xeon Scalable processor family and Intel Core processor family.
Definition dnnl.hpp:10472
@ avx512_mic_4ops
Intel AVX-512 subset for Intel Xeon Phi processors 7235, 7285, 7295 Series.
Definition dnnl.hpp:10470
@ sse41
Intel Streaming SIMD Extensions 4.1 (Intel SSE4.1)
Definition dnnl.hpp:10462
@ avx512_core_vnni
Intel AVX-512 and Intel Deep Learning Boost (Intel DL Boost) support for Intel Xeon Scalable processo...
Definition dnnl.hpp:10474
@ avx512_core_bf16
Intel AVX-512, Intel DL Boost and bfloat16 support for Intel Xeon Scalable processor family and Intel...
Definition dnnl.hpp:10476
dnnl_status_t DNNL_API dnnl_shuffle_forward_desc_init(dnnl_shuffle_desc_t *shuffle_desc, dnnl_prop_kind_t prop_kind, const dnnl_memory_desc_t *data_desc, int axis, dnnl_dim_t group_size)
Initializes a descriptor for shuffle forward propagation primitive.
dnnl_status_t DNNL_API dnnl_shuffle_backward_desc_init(dnnl_shuffle_desc_t *shuffle_desc, const dnnl_memory_desc_t *diff_data_desc, int axis, dnnl_dim_t group_size)
Initializes a descriptor for shuffle backward propagation primitive.
dnnl_status_t DNNL_API dnnl_softmax_backward_desc_init(dnnl_softmax_desc_t *softmax_desc, const dnnl_memory_desc_t *diff_data_desc, const dnnl_memory_desc_t *data_desc, int softmax_axis)
Initializes a descriptor for softmax backward propagation primitive.
dnnl_status_t DNNL_API dnnl_softmax_forward_desc_init(dnnl_softmax_desc_t *softmax_desc, dnnl_prop_kind_t prop_kind, const dnnl_memory_desc_t *data_desc, int softmax_axis)
Initializes a descriptor for softmax forward propagation primitive.
struct dnnl_stream_attr * dnnl_stream_attr_t
An execution stream attributes handle.
Definition dnnl_types.h:2091
dnnl_stream_flags_t
Stream flags.
Definition dnnl_types.h:2068
dnnl_status_t DNNL_API dnnl_stream_create_ocl(dnnl_stream_t *stream, dnnl_engine_t engine, cl_command_queue queue)
Creates an execution stream for a given engine associated with an OpenCL command queue.
dnnl_status_t DNNL_API dnnl_stream_wait(dnnl_stream_t stream)
Waits for all primitives in the execution stream to finish computations.
struct dnnl_stream * dnnl_stream_t
An execution stream handle.
Definition dnnl_types.h:2084
dnnl_status_t DNNL_API dnnl_stream_get_ocl_command_queue(dnnl_stream_t stream, cl_command_queue *queue)
Returns the OpenCL command queue associated with an execution stream.
dnnl_status_t DNNL_API dnnl_stream_create_v2(dnnl_stream_t *stream, dnnl_engine_t engine, unsigned flags, const_dnnl_stream_attr_t attr)
Creates an execution stream.
dnnl_status_t DNNL_API dnnl_stream_attr_destroy(dnnl_stream_attr_t attr)
Destroys execution stream attributes.
dnnl_status_t DNNL_API dnnl_stream_destroy(dnnl_stream_t stream)
Destroys an execution stream.
dnnl_status_t DNNL_API dnnl_stream_attr_create(dnnl_stream_attr_t *attr, dnnl_engine_kind_t kind)
Creates execution stream attributes for a stream that runs on an engine of a particular kind.
@ dnnl_stream_default_order
Default order execution.
Definition dnnl_types.h:2071
@ dnnl_stream_out_of_order
Out-of-order execution.
Definition dnnl_types.h:2075
@ dnnl_stream_default_flags
Default stream configuration.
Definition dnnl_types.h:2077
dnnl_status_t DNNL_API dnnl_sum_primitive_desc_create(dnnl_primitive_desc_t *sum_primitive_desc, const dnnl_memory_desc_t *dst_desc, int n, const float *scales, const dnnl_memory_desc_t *src_descs, const_dnnl_primitive_attr_t attr, dnnl_engine_t engine)
Creates a primitive descriptor for an (out-of-place) sum primitive.
dnnl_status_t
Status values returned by the library functions.
Definition dnnl_types.h:39
@ dnnl_iterator_ends
Primitive iterator passed over last primitive descriptor.
Definition dnnl_types.h:49
@ dnnl_runtime_error
Primitive or engine failed on execution.
Definition dnnl_types.h:51
@ dnnl_unimplemented
The operation failed because requested functionality is not implemented.
Definition dnnl_types.h:47
@ dnnl_out_of_memory
The operation failed due to an out-of-memory condition.
Definition dnnl_types.h:43
@ dnnl_success
The operation was successful.
Definition dnnl_types.h:41
@ dnnl_invalid_arguments
The operation failed because of incorrect function arguments.
Definition dnnl_types.h:45
@ dnnl_not_required
Queried element is not required for given primitive.
Definition dnnl_types.h:53
oneDNN namespace
Definition dnnl.hpp:81
Descriptor for a batch normalization backward propagation primitive.
Definition dnnl.hpp:6311
desc(prop_kind prop_kind, const memory::desc &diff_data_desc, const memory::desc &data_desc, float epsilon, normalization_flags flags)
Constructs a batch normalization descriptor for backward propagation.
Definition dnnl.hpp:6345
Primitive descriptor for a batch normalization backward propagation primitive.
Definition dnnl.hpp:6359
memory::desc weights_desc() const
Returns a weights memory descriptor.
Definition dnnl.hpp:6419
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for a batch normalization backward propagation primitive from a C A...
Definition dnnl.hpp:6409
primitive_desc()=default
Default constructor. Produces an empty object.
memory::desc workspace_desc() const
Returns the workspace memory descriptor.
Definition dnnl.hpp:6444
memory::desc diff_src_desc() const
Returns a diff source memory descriptor.
Definition dnnl.hpp:6425
memory::desc variance_desc() const
Returns memory descriptor for variance.
Definition dnnl.hpp:6439
memory::desc dst_desc() const
Returns a destination memory descriptor.
Definition dnnl.hpp:6422
primitive_desc(const desc &desc, const engine &engine, const batch_normalization_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for a batch normalization backward propagation primitive.
Definition dnnl.hpp:6376
memory::desc src_desc() const
Returns a source memory descriptor.
Definition dnnl.hpp:6416
memory::desc diff_dst_desc() const
Returns a diff destination memory descriptor.
Definition dnnl.hpp:6428
memory::desc diff_weights_desc() const
Returns a diff weights memory descriptor.
Definition dnnl.hpp:6431
memory::desc mean_desc() const
Returns memory descriptor for mean.
Definition dnnl.hpp:6436
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &engine, const batch_normalization_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for a batch normalization backward propagation primitive.
Definition dnnl.hpp:6396
batch_normalization_backward()=default
Default constructor. Produces an empty object.
batch_normalization_backward(const primitive_desc &pd)
Constructs a batch normalization backward propagation primitive.
Definition dnnl.hpp:6453
Descriptor for a batch normalization forward propagation primitive.
Definition dnnl.hpp:6153
desc(prop_kind prop_kind, const memory::desc &data_desc, float epsilon, normalization_flags flags)
Constructs a batch normalization descriptor for forward propagation.
Definition dnnl.hpp:6199
Primitive descriptor for a batch normalization forward propagation primitive.
Definition dnnl.hpp:6212
memory::desc src_desc() const
Returns a source memory descriptor.
Definition dnnl.hpp:6260
primitive_desc()=default
Default constructor. Produces an empty object.
memory::desc weights_desc() const
Returns a weights memory descriptor.
Definition dnnl.hpp:6266
memory::desc mean_desc() const
Returns memory descriptor for mean.
Definition dnnl.hpp:6273
memory::desc workspace_desc() const
Returns the workspace memory descriptor.
Definition dnnl.hpp:6269
primitive_desc(const desc &desc, const engine &engine, bool allow_empty=false)
Constructs a primitive descriptor for a batch normalization forward propagation primitive.
Definition dnnl.hpp:6226
memory::desc variance_desc() const
Returns memory descriptor for variance.
Definition dnnl.hpp:6277
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for a batch normalization forward propagation primitive from a C AP...
Definition dnnl.hpp:6253
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &engine, bool allow_empty=false)
Constructs a primitive descriptor for a batch normalization forward propagation primitive.
Definition dnnl.hpp:6242
memory::desc dst_desc() const
Returns a destination memory descriptor.
Definition dnnl.hpp:6263
batch_normalization_forward()=default
Default constructor. Produces an empty object.
batch_normalization_forward(const primitive_desc &pd)
Constructs a batch normalization forward propagation primitive.
Definition dnnl.hpp:6305
Descriptor for an elementwise binary operator primitive.
Definition dnnl.hpp:9885
desc()=default
Default constructor. Produces an empty object.
dnnl_binary_desc_t data
Underlying C operation descriptor.
Definition dnnl.hpp:9887
desc(algorithm algorithm, const memory::desc &src0, const memory::desc &src1, const memory::desc &dst)
Constructs a descriptor for an elementwise binary operator primitive.
Definition dnnl.hpp:9906
Primitive descriptor for an elementwise binary operator primitive.
Definition dnnl.hpp:9917
primitive_desc(const desc &desc, const engine &engine, bool allow_empty=false)
Constructs a primitive descriptor for an elementwise binary operator primitive.
Definition dnnl.hpp:9930
memory::desc src_desc(int idx=0) const
Returns a source memory descriptor.
Definition dnnl.hpp:9958
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &engine, bool allow_empty=false)
Constructs a primitive descriptor for an elementwise binary operator primitive.
Definition dnnl.hpp:9945
memory::desc src0_desc() const
Returns the memory descriptor for source #0.
Definition dnnl.hpp:9961
primitive_desc()=default
Default constructor. Produces an empty object.
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for a binary primitive from a C API primitive descriptor that must ...
Definition dnnl.hpp:9954
memory::desc dst_desc() const
Returns a destination memory descriptor.
Definition dnnl.hpp:9967
memory::desc src1_desc() const
Returns the memory descriptor for source #1.
Definition dnnl.hpp:9964
binary()=default
Default constructor. Produces an empty object.
binary(const primitive_desc &pd)
Constructs an elementwise binary operation primitive.
Definition dnnl.hpp:9976
Primitive descriptor for a concat primitive.
Definition dnnl.hpp:3256
memory::desc dst_desc() const
Returns a destination memory descriptor.
Definition dnnl.hpp:3334
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for concat primitive from a C API primitive descriptor which must h...
Definition dnnl.hpp:3327
primitive_desc_base()=default
Default constructor. Produces an empty object.
primitive_desc()=default
Default constructor. Produces an empty object.
primitive_desc(int concat_dimension, const std::vector< memory::desc > &srcs, const engine &engine, const primitive_attr &attr=primitive_attr())
Constructs a primitive descriptor for an out-of-place concatenation primitive.
Definition dnnl.hpp:3308
primitive_desc(const memory::desc &dst, int concat_dimension, const std::vector< memory::desc > &srcs, const engine &engine, const primitive_attr &attr=primitive_attr())
Constructs a primitive descriptor for an out-of-place concatenation primitive.
Definition dnnl.hpp:3281
memory::desc src_desc(int idx=0) const
Returns a source memory descriptor.
Definition dnnl.hpp:3331
concat()=default
Default constructor. Produces an empty object.
concat(const primitive_desc &pd)
Constructs a concatenation primitive.
Definition dnnl.hpp:3342
Descriptor for a convolution backward propagation primitive.
Definition dnnl.hpp:3821
desc(algorithm algorithm, const memory::desc &diff_src_desc, const memory::desc &weights_desc, const memory::desc &diff_dst_desc, const memory::dims &strides, const memory::dims &padding_l, const memory::dims &padding_r)
Constructs a descriptor for a convolution backward propagation primitive.
Definition dnnl.hpp:3850
desc(algorithm algorithm, const memory::desc &diff_src_desc, const memory::desc &weights_desc, const memory::desc &diff_dst_desc, const memory::dims &strides, const memory::dims &dilates, const memory::dims &padding_l, const memory::dims &padding_r)
Constructs a descriptor for dilated convolution backward propagation primitive.
Definition dnnl.hpp:3894
Primitive descriptor for a convolution backward propagation primitive.
Definition dnnl.hpp:3915
memory::desc weights_desc() const
Returns a weights memory descriptor.
Definition dnnl.hpp:3973
memory::desc diff_dst_desc() const
Returns a diff destination memory descriptor.
Definition dnnl.hpp:3976
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for a convolution backward propagation primitive from a C API primi...
Definition dnnl.hpp:3965
primitive_desc(const desc &desc, const engine &engine, const convolution_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for a convolution backward propagation primitive.
Definition dnnl.hpp:3932
primitive_desc()=default
Default constructor. Produces an empty object.
memory::desc diff_src_desc() const
Returns a diff source memory descriptor.
Definition dnnl.hpp:3970
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &engine, const convolution_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for a convolution backward propagation primitive.
Definition dnnl.hpp:3952
convolution_backward_data()=default
Default constructor. Produces an empty object.
convolution_backward_data(const primitive_desc &pd)
Constructs a convolution backward propagation primitive.
Definition dnnl.hpp:3985
Descriptor for a convolution weights gradient primitive.
Definition dnnl.hpp:3991
desc(algorithm algorithm, const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_bias_desc, const memory::desc &diff_dst_desc, const memory::dims &strides, const memory::dims &dilates, const memory::dims &padding_l, const memory::dims &padding_r)
Constructs a descriptor for a dilated convolution weights gradient primitive with bias.
Definition dnnl.hpp:4114
desc(algorithm algorithm, const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_dst_desc, const memory::dims &strides, const memory::dims &dilates, const memory::dims &padding_l, const memory::dims &padding_r)
Constructs a descriptor for a dilated convolution weights gradient primitive without bias.
Definition dnnl.hpp:4162
desc(algorithm algorithm, const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_bias_desc, const memory::desc &diff_dst_desc, const memory::dims &strides, const memory::dims &padding_l, const memory::dims &padding_r)
Constructs a descriptor for a convolution weights gradient primitive with bias.
Definition dnnl.hpp:4023
desc(algorithm algorithm, const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_dst_desc, const memory::dims &strides, const memory::dims &padding_l, const memory::dims &padding_r)
Constructs a descriptor for a convolution weights gradient primitive without bias.
Definition dnnl.hpp:4067
Primitive descriptor for a convolution weights gradient primitive.
Definition dnnl.hpp:4183
memory::desc diff_bias_desc() const
Returns the diff bias memory descriptor.
Definition dnnl.hpp:4250
memory::desc diff_weights_desc() const
Returns a diff weights memory descriptor.
Definition dnnl.hpp:4239
primitive_desc(const desc &desc, const engine &engine, const convolution_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for a convolution weights gradient primitive.
Definition dnnl.hpp:4199
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for a convolution weights gradient primitive from a C API primitive descriptor that must have a matching kind.
Definition dnnl.hpp:4231
memory::desc src_desc() const
Returns a source memory descriptor.
Definition dnnl.hpp:4236
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &engine, const convolution_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for a convolution weights gradient primitive.
Definition dnnl.hpp:4218
primitive_desc()=default
Default constructor. Produces an empty object.
memory::desc diff_dst_desc() const
Returns a diff destination memory descriptor.
Definition dnnl.hpp:4244
convolution_backward_weights()=default
Default constructor. Produces an empty object.
convolution_backward_weights(const primitive_desc &pd)
Constructs a convolution weights gradient primitive.
Definition dnnl.hpp:4261
Descriptor for a convolution forward propagation primitive.
Definition dnnl.hpp:3542
desc(prop_kind prop_kind, algorithm algorithm, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &dst_desc, const memory::dims &strides, const memory::dims &padding_l, const memory::dims &padding_r)
Constructs a descriptor for a convolution forward propagation primitive without bias.
Definition dnnl.hpp:3624
desc(prop_kind prop_kind, algorithm algorithm, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &bias_desc, const memory::desc &dst_desc, const memory::dims &strides, const memory::dims &dilates, const memory::dims &padding_l, const memory::dims &padding_r)
Constructs a descriptor for a dilated convolution forward propagation primitive with bias.
Definition dnnl.hpp:3675
desc(prop_kind prop_kind, algorithm algorithm, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &bias_desc, const memory::desc &dst_desc, const memory::dims &strides, const memory::dims &padding_l, const memory::dims &padding_r)
Constructs a descriptor for a convolution forward propagation primitive with bias.
Definition dnnl.hpp:3577
desc(prop_kind prop_kind, algorithm algorithm, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &dst_desc, const memory::dims &strides, const memory::dims &dilates, const memory::dims &padding_l, const memory::dims &padding_r)
Constructs a descriptor for a dilated convolution forward propagation primitive without bias.
Definition dnnl.hpp:3725
Primitive descriptor for a convolution forward propagation primitive.
Definition dnnl.hpp:3746
primitive_desc()=default
Default constructor. Produces an empty object.
memory::desc src_desc() const
Returns a source memory descriptor.
Definition dnnl.hpp:3793
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &engine, bool allow_empty=false)
Constructs a primitive descriptor for a convolution forward propagation primitive.
Definition dnnl.hpp:3776
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for a convolution forward propagation primitive from a C API primitive descriptor that must have a matching kind.
Definition dnnl.hpp:3787
memory::desc bias_desc() const
Returns the bias memory descriptor.
Definition dnnl.hpp:3805
primitive_desc(const desc &desc, const engine &engine, bool allow_empty=false)
Constructs a primitive descriptor for a convolution forward propagation primitive.
Definition dnnl.hpp:3760
memory::desc weights_desc() const
Returns a weights memory descriptor.
Definition dnnl.hpp:3796
memory::desc dst_desc() const
Returns a destination memory descriptor.
Definition dnnl.hpp:3799
convolution_forward(const primitive_desc &pd)
Constructs a convolution forward propagation primitive.
Definition dnnl.hpp:3814
convolution_forward()=default
Default constructor. Produces an empty object.
Descriptor for a deconvolution backward propagation primitive.
Definition dnnl.hpp:4548
desc(algorithm algorithm, const memory::desc &diff_src_desc, const memory::desc &weights_desc, const memory::desc &diff_dst_desc, const memory::dims &strides, const memory::dims &padding_l, const memory::dims &padding_r)
Constructs a descriptor for a deconvolution backward propagation primitive.
Definition dnnl.hpp:4576
desc(algorithm algorithm, const memory::desc &diff_src_desc, const memory::desc &weights_desc, const memory::desc &diff_dst_desc, const memory::dims &strides, const memory::dims &dilates, const memory::dims &padding_l, const memory::dims &padding_r)
Constructs a descriptor for a dilated deconvolution backward propagation primitive.
Definition dnnl.hpp:4619
Primitive descriptor for a deconvolution backward propagation primitive.
Definition dnnl.hpp:4640
memory::desc weights_desc() const
Returns a weights memory descriptor.
Definition dnnl.hpp:4698
memory::desc diff_dst_desc() const
Returns a diff destination memory descriptor.
Definition dnnl.hpp:4701
primitive_desc(const desc &desc, const engine &engine, const deconvolution_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for a deconvolution backward propagation primitive.
Definition dnnl.hpp:4657
memory::desc diff_src_desc() const
Returns a diff source memory descriptor.
Definition dnnl.hpp:4695
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &engine, const deconvolution_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for a deconvolution backward propagation primitive.
Definition dnnl.hpp:4677
primitive_desc()=default
Default constructor. Produces an empty object.
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for a deconvolution backward propagation primitive from a C API primitive descriptor that must have a matching kind.
Definition dnnl.hpp:4690
deconvolution_backward_data()=default
Default constructor. Produces an empty object.
deconvolution_backward_data(const primitive_desc &pd)
Constructs a deconvolution backward propagation primitive.
Definition dnnl.hpp:4710
Descriptor for a deconvolution weights gradient primitive.
Definition dnnl.hpp:4716
desc(algorithm algorithm, const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_bias_desc, const memory::desc &diff_dst_desc, const memory::dims &strides, const memory::dims &dilates, const memory::dims &padding_l, const memory::dims &padding_r)
Constructs a descriptor for a dilated deconvolution weights gradient primitive with bias.
Definition dnnl.hpp:4836
desc(algorithm algorithm, const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_dst_desc, const memory::dims &strides, const memory::dims &dilates, const memory::dims &padding_l, const memory::dims &padding_r)
Constructs a descriptor for a dilated deconvolution weights gradient primitive without bias.
Definition dnnl.hpp:4883
desc(algorithm algorithm, const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_bias_desc, const memory::desc &diff_dst_desc, const memory::dims &strides, const memory::dims &padding_l, const memory::dims &padding_r)
Constructs a descriptor for a deconvolution weights gradient primitive with bias.
Definition dnnl.hpp:4747
desc(algorithm algorithm, const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_dst_desc, const memory::dims &strides, const memory::dims &padding_l, const memory::dims &padding_r)
Constructs a descriptor for a deconvolution weights gradient primitive without bias.
Definition dnnl.hpp:4790
Primitive descriptor for a deconvolution weights gradient primitive.
Definition dnnl.hpp:4904
memory::desc src_desc() const
Returns a source memory descriptor.
Definition dnnl.hpp:4959
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &engine, const deconvolution_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for a deconvolution weights gradient primitive.
Definition dnnl.hpp:4941
memory::desc diff_dst_desc() const
Returns a diff destination memory descriptor.
Definition dnnl.hpp:4967
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for a deconvolution weights gradient primitive from a C API primitive descriptor that must have a matching kind.
Definition dnnl.hpp:4954
memory::desc diff_weights_desc() const
Returns a diff weights memory descriptor.
Definition dnnl.hpp:4962
memory::desc diff_bias_desc() const
Returns the diff bias memory descriptor.
Definition dnnl.hpp:4970
primitive_desc(const desc &desc, const engine &engine, const deconvolution_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for a deconvolution weights gradient primitive.
Definition dnnl.hpp:4921
primitive_desc()=default
Default constructor. Produces an empty object.
deconvolution_backward_weights()=default
Default constructor. Produces an empty object.
deconvolution_backward_weights(const primitive_desc &pd)
Constructs a deconvolution weights gradient primitive.
Definition dnnl.hpp:4981
Descriptor for a deconvolution forward propagation primitive.
Definition dnnl.hpp:4277
desc(prop_kind prop_kind, algorithm algorithm, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &bias_desc, const memory::desc &dst_desc, const memory::dims &strides, const memory::dims &padding_l, const memory::dims &padding_r)
Constructs a descriptor for a deconvolution forward propagation primitive with bias.
Definition dnnl.hpp:4311
desc(prop_kind prop_kind, algorithm algorithm, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &dst_desc, const memory::dims &strides, const memory::dims &padding_l, const memory::dims &padding_r)
Constructs a descriptor for a deconvolution forward propagation primitive without bias.
Definition dnnl.hpp:4357
desc(prop_kind prop_kind, algorithm algorithm, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &bias_desc, const memory::desc &dst_desc, const memory::dims &strides, const memory::dims &dilates, const memory::dims &padding_l, const memory::dims &padding_r)
Constructs a descriptor for a dilated deconvolution forward propagation primitive with bias.
Definition dnnl.hpp:4407
desc(prop_kind prop_kind, algorithm algorithm, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &dst_desc, const memory::dims &strides, const memory::dims &dilates, const memory::dims &padding_l, const memory::dims &padding_r)
Constructs a descriptor for a dilated deconvolution forward propagation primitive without bias.
Definition dnnl.hpp:4456
Primitive descriptor for a deconvolution forward propagation primitive.
Definition dnnl.hpp:4477
primitive_desc(const desc &desc, const engine &engine, bool allow_empty=false)
Constructs a primitive descriptor for a deconvolution forward propagation primitive.
Definition dnnl.hpp:4491
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for a deconvolution forward propagation primitive from a C API primitive descriptor that must have a matching kind.
Definition dnnl.hpp:4518
memory::desc dst_desc() const
Returns a destination memory descriptor.
Definition dnnl.hpp:4530
memory::desc src_desc() const
Returns a source memory descriptor.
Definition dnnl.hpp:4524
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &engine, bool allow_empty=false)
Constructs a primitive descriptor for a deconvolution forward propagation primitive.
Definition dnnl.hpp:4507
primitive_desc()=default
Default constructor. Produces an empty object.
memory::desc bias_desc() const
Returns the bias memory descriptor.
Definition dnnl.hpp:4533
memory::desc weights_desc() const
Returns a weights memory descriptor.
Definition dnnl.hpp:4527
deconvolution_forward(const primitive_desc &pd)
Constructs a deconvolution forward propagation primitive.
Definition dnnl.hpp:4542
deconvolution_forward()=default
Default constructor. Produces an empty object.
Descriptor for an elementwise backward propagation primitive.
Definition dnnl.hpp:5587
desc(algorithm algorithm, const memory::desc &diff_data_desc, const memory::desc &data_desc, float alpha=0, float beta=0)
Constructs a descriptor for an elementwise backward propagation primitive.
Definition dnnl.hpp:5608
Primitive descriptor for an elementwise backward propagation primitive.
Definition dnnl.hpp:5621
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &engine, const eltwise_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for an elementwise backward propagation primitive.
Definition dnnl.hpp:5658
memory::desc diff_src_desc() const
Returns a diff source memory descriptor.
Definition dnnl.hpp:5679
primitive_desc()=default
Default constructor. Produces an empty object.
memory::desc src_desc() const
Returns a source memory descriptor.
Definition dnnl.hpp:5676
primitive_desc(const desc &desc, const engine &engine, const eltwise_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for an elementwise backward propagation primitive.
Definition dnnl.hpp:5638
memory::desc diff_dst_desc() const
Returns a diff destination memory descriptor.
Definition dnnl.hpp:5682
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for an eltwise backward propagation primitive from a C API primitive descriptor that must have a matching kind.
Definition dnnl.hpp:5671
eltwise_backward()=default
Default constructor. Produces an empty object.
eltwise_backward(const primitive_desc &pd)
Constructs an eltwise backward propagation primitive.
Definition dnnl.hpp:5691
Descriptor for an elementwise forward propagation primitive.
Definition dnnl.hpp:5488
desc(prop_kind prop_kind, algorithm algorithm, const memory::desc &data_desc, float alpha=0, float beta=0)
Constructs a descriptor for an elementwise forward propagation primitive.
Definition dnnl.hpp:5509
Primitive descriptor for an elementwise forward propagation primitive.
Definition dnnl.hpp:5522
primitive_desc()=default
Default constructor. Produces an empty object.
memory::desc dst_desc() const
Returns a destination memory descriptor.
Definition dnnl.hpp:5572
primitive_desc(const desc &desc, const engine &engine, bool allow_empty=false)
Constructs a primitive descriptor for an elementwise forward propagation primitive.
Definition dnnl.hpp:5536
memory::desc src_desc() const
Returns a source memory descriptor.
Definition dnnl.hpp:5569
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &engine, bool allow_empty=false)
Constructs a primitive descriptor for an elementwise forward propagation primitive.
Definition dnnl.hpp:5552
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for an eltwise forward propagation primitive from a C API primitive descriptor that must have a matching kind.
Definition dnnl.hpp:5563
eltwise_forward(const primitive_desc &pd)
Constructs an eltwise forward propagation primitive.
Definition dnnl.hpp:5581
eltwise_forward()=default
Default constructor. Produces an empty object.
An execution engine.
Definition dnnl.hpp:844
static engine query(const primitive_desc &pd)
Returns the engine of a primitive descriptor.
Definition dnnl.hpp:949
kind
Kinds of engines.
Definition dnnl.hpp:849
@ gpu
GPU engine.
Definition dnnl.hpp:855
@ any
An unspecified engine.
Definition dnnl.hpp:851
@ cpu
CPU engine.
Definition dnnl.hpp:853
static size_t get_count(kind kind)
Returns the number of engines of a certain kind.
Definition dnnl.hpp:868
engine(kind kind, cl_device_id device, cl_context context)
Constructs an engine from OpenCL device and context objects.
Definition dnnl.hpp:892
engine()=default
Constructs an empty engine.
handle()=default
Constructs an empty handle object.
cl_device_id get_ocl_device() const
Returns the OpenCL device associated with the engine.
Definition dnnl.hpp:935
engine(const handle< dnnl_primitive_desc_t > &pd)
Constructs an engine by querying the engine associated with the primitive descriptor pd.
Definition dnnl.hpp:905
engine(kind kind, size_t index)
Constructs an engine.
Definition dnnl.hpp:877
cl_context get_ocl_context() const
Returns the OpenCL context associated with the engine.
Definition dnnl.hpp:926
kind get_kind() const
Returns the kind of the engine.
Definition dnnl.hpp:916
oneDNN exception class.
Definition dnnl.hpp:91
error(dnnl_status_t status, const char *message)
Constructs an instance of an exception class.
Definition dnnl.hpp:99
static void wrap_c_api(dnnl_status_t status, const char *message)
A convenience function for wrapping calls to C API functions.
Definition dnnl.hpp:110
const char * what() const noexcept override
Returns the explanatory string.
Definition dnnl.hpp:103
Descriptor for a GRU backward propagation primitive.
Definition dnnl.hpp:9064
desc(prop_kind prop_kind, rnn_direction direction, const memory::desc &src_layer_desc, const memory::desc &src_iter_desc, const memory::desc &weights_layer_desc, const memory::desc &weights_iter_desc, const memory::desc &bias_desc, const memory::desc &dst_layer_desc, const memory::desc &dst_iter_desc, const memory::desc &diff_src_layer_desc, const memory::desc &diff_src_iter_desc, const memory::desc &diff_weights_layer_desc, const memory::desc &diff_weights_iter_desc, const memory::desc &diff_bias_desc, const memory::desc &diff_dst_layer_desc, const memory::desc &diff_dst_iter_desc, rnn_flags flags=rnn_flags::undef)
Constructs a descriptor for a GRU backward propagation primitive.
Definition dnnl.hpp:9133
Primitive descriptor for a GRU backward propagation primitive.
Definition dnnl.hpp:9167
memory::desc diff_weights_iter_desc() const
Returns diff weights iteration memory descriptor.
Definition dnnl.hpp:9269
memory::desc dst_layer_desc() const
Returns destination layer memory descriptor.
Definition dnnl.hpp:9241
memory::desc weights_layer_desc() const
Returns weights layer memory descriptor.
Definition dnnl.hpp:9228
memory::desc src_iter_desc() const
Returns source iteration memory descriptor.
Definition dnnl.hpp:9225
memory::desc diff_bias_desc() const
Returns diff bias memory descriptor.
Definition dnnl.hpp:9274
memory::desc weights_iter_desc() const
Returns weights iteration memory descriptor.
Definition dnnl.hpp:9233
primitive_desc(const desc &desc, const engine &engine, const gru_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for a GRU backward propagation primitive.
Definition dnnl.hpp:9183
memory::desc bias_desc() const
Returns bias memory descriptor.
Definition dnnl.hpp:9238
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &engine, const gru_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for a GRU backward propagation primitive.
Definition dnnl.hpp:9202
primitive_desc()=default
Default constructor. Produces an empty object.
memory::desc diff_dst_iter_desc() const
Returns diff destination iteration memory descriptor.
Definition dnnl.hpp:9284
memory::desc diff_dst_layer_desc() const
Returns diff destination layer memory descriptor.
Definition dnnl.hpp:9279
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for a GRU backward propagation primitive from a C API primitive descriptor that must have a matching kind.
Definition dnnl.hpp:9215
memory::desc src_layer_desc() const
Returns source layer memory descriptor.
Definition dnnl.hpp:9220
memory::desc workspace_desc() const
Returns the workspace memory descriptor.
Definition dnnl.hpp:9249
memory::desc diff_src_layer_desc() const
Returns diff source layer memory descriptor.
Definition dnnl.hpp:9254
memory::desc diff_src_iter_desc() const
Returns diff source iteration memory descriptor.
Definition dnnl.hpp:9259
memory::desc diff_weights_layer_desc() const
Returns diff weights layer memory descriptor.
Definition dnnl.hpp:9264
memory::desc dst_iter_desc() const
Returns destination iteration memory descriptor.
Definition dnnl.hpp:9246
gru_backward()=default
Default constructor. Produces an empty object.
gru_backward(const primitive_desc &pd)
Constructs a GRU backward propagation primitive.
Definition dnnl.hpp:9295
Descriptor for a GRU forward propagation primitive.
Definition dnnl.hpp:8902
desc(prop_kind prop_kind, rnn_direction direction, const memory::desc &src_layer_desc, const memory::desc &src_iter_desc, const memory::desc &weights_layer_desc, const memory::desc &weights_iter_desc, const memory::desc &bias_desc, const memory::desc &dst_layer_desc, const memory::desc &dst_iter_desc, rnn_flags flags=rnn_flags::undef)
Constructs a descriptor for a GRU forward propagation primitive.
Definition dnnl.hpp:8950
Primitive descriptor for a GRU forward propagation primitive.
Definition dnnl.hpp:8973
memory::desc weights_iter_desc() const
Returns weights iteration memory descriptor.
Definition dnnl.hpp:9031
memory::desc src_layer_desc() const
Returns source layer memory descriptor.
Definition dnnl.hpp:9018
memory::desc dst_layer_desc() const
Returns destination layer memory descriptor.
Definition dnnl.hpp:9039
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &engine, bool allow_empty=false)
Constructs a primitive descriptor for a GRU forward propagation primitive.
Definition dnnl.hpp:9001
memory::desc weights_layer_desc() const
Returns weights layer memory descriptor.
Definition dnnl.hpp:9026
memory::desc bias_desc() const
Returns bias memory descriptor.
Definition dnnl.hpp:9036
primitive_desc()=default
Default constructor. Produces an empty object.
memory::desc dst_iter_desc() const
Returns destination iteration memory descriptor.
Definition dnnl.hpp:9044
memory::desc workspace_desc() const
Returns the workspace memory descriptor.
Definition dnnl.hpp:9047
memory::desc src_iter_desc() const
Returns source iteration memory descriptor.
Definition dnnl.hpp:9023
primitive_desc(const desc &desc, const engine &engine, bool allow_empty=false)
Constructs a primitive descriptor for a GRU forward propagation primitive.
Definition dnnl.hpp:8986
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for a GRU forward propagation primitive from a C API primitive descriptor that must have a matching kind.
Definition dnnl.hpp:9012
gru_forward(const primitive_desc &pd)
Constructs a GRU forward propagation primitive.
Definition dnnl.hpp:9058
gru_forward()=default
Default constructor. Produces an empty object.
A class that provides the destructor for a oneDNN C API handle.
Definition dnnl.hpp:127
oneDNN C API handle wrapper class.
Definition dnnl.hpp:143
handle(const handle< T, traits > &)=default
Copy constructor.
bool operator==(const handle< T, traits > &other) const
Equality operator.
Definition dnnl.hpp:217
handle< T, traits > & operator=(const handle< T, traits > &)=default
Assignment operator.
bool operator!=(const handle &other) const
Inequality operator.
Definition dnnl.hpp:227
T get(bool allow_empty=false) const
Returns the underlying C API handle.
Definition dnnl.hpp:192
handle()=default
Constructs an empty handle object.
void reset(T t, bool weak=false)
Resets the handle wrapper object to wrap a new C API handle.
Definition dnnl.hpp:183
handle(T t, bool weak=false)
Constructs a handle wrapper object from a C API handle.
Definition dnnl.hpp:176
handle(handle< T, traits > &&)=default
Move constructor.
handle< T, traits > & operator=(handle< T, traits > &&)=default
Move assignment operator.
Descriptor for an inner product backward propagation primitive.
Definition dnnl.hpp:7010
desc(const memory::desc &diff_src_desc, const memory::desc &weights_desc, const memory::desc &diff_dst_desc)
Constructs a descriptor for an inner product backward propagation primitive.
Definition dnnl.hpp:7030
Primitive descriptor for an inner product backward propagation primitive.
Definition dnnl.hpp:7043
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &engine, const inner_product_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for an inner product backward propagation primitive.
Definition dnnl.hpp:7080
memory::desc diff_dst_desc() const
Returns a diff destination memory descriptor.
Definition dnnl.hpp:7104
memory::desc weights_desc() const
Returns a weights memory descriptor.
Definition dnnl.hpp:7101
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for an inner product backward propagation primitive from a C API primitive descriptor that must have a matching kind.
Definition dnnl.hpp:7093
memory::desc diff_src_desc() const
Returns a diff source memory descriptor.
Definition dnnl.hpp:7098
primitive_desc()=default
Default constructor. Produces an empty object.
primitive_desc(const desc &desc, const engine &engine, const inner_product_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for an inner product backward propagation primitive.
Definition dnnl.hpp:7060
inner_product_backward_data(const primitive_desc &pd)
Constructs an inner product backward propagation primitive.
Definition dnnl.hpp:7113
inner_product_backward_data()=default
Default constructor. Produces an empty object.
Descriptor for an inner product weights gradient primitive.
Definition dnnl.hpp:7119
desc(const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_bias_desc, const memory::desc &diff_dst_desc)
Constructs a descriptor for an inner product weights gradient primitive with bias.
Definition dnnl.hpp:7141
desc(const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_dst_desc)
Constructs a descriptor for an inner product weights gradient primitive without bias.
Definition dnnl.hpp:7170
Primitive descriptor for an inner product weights gradient primitive.
Definition dnnl.hpp:7183
memory::desc src_desc() const
Returns a source memory descriptor.
Definition dnnl.hpp:7238
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &engine, const inner_product_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for an inner product weights gradient primitive.
Definition dnnl.hpp:7220
memory::desc diff_weights_desc() const
Returns a diff weights memory descriptor.
Definition dnnl.hpp:7241
primitive_desc()=default
Default constructor. Produces an empty object.
memory::desc diff_dst_desc() const
Returns a diff destination memory descriptor.
Definition dnnl.hpp:7246
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for an inner product weights gradient primitive from a C API primitive descriptor that must have a matching kind.
Definition dnnl.hpp:7233
primitive_desc(const desc &desc, const engine &engine, const inner_product_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for an inner product weights gradient primitive.
Definition dnnl.hpp:7200
memory::desc diff_bias_desc() const
Returns the diff bias memory descriptor.
Definition dnnl.hpp:7249
inner_product_backward_weights(const primitive_desc &pd)
Constructs an inner product weights gradient primitive.
Definition dnnl.hpp:7260
inner_product_backward_weights()=default
Default constructor. Produces an empty object.
Descriptor for an inner product forward propagation primitive.
Definition dnnl.hpp:6870
desc(prop_kind prop_kind, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &bias_desc, const memory::desc &dst_desc)
Constructs a descriptor for an inner product forward propagation primitive with bias.
Definition dnnl.hpp:6895
desc(prop_kind prop_kind, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &dst_desc)
Constructs a descriptor for an inner product forward propagation primitive without bias.
Definition dnnl.hpp:6926
Primitive descriptor for an inner product forward propagation primitive.
Definition dnnl.hpp:6939
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for an inner product forward propagation primitive from a C API primitive descriptor that must have a matching kind.
Definition dnnl.hpp:6980
memory::desc dst_desc() const
Returns a destination memory descriptor.
Definition dnnl.hpp:6992
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &engine, bool allow_empty=false)
Constructs a primitive descriptor for an inner product forward propagation primitive.
Definition dnnl.hpp:6969
memory::desc weights_desc() const
Returns a weights memory descriptor.
Definition dnnl.hpp:6989
primitive_desc()=default
Default constructor. Produces an empty object.
primitive_desc(const desc &desc, const engine &engine, bool allow_empty=false)
Constructs a primitive descriptor for an inner product forward propagation primitive.
Definition dnnl.hpp:6953
memory::desc bias_desc() const
Returns the bias memory descriptor.
Definition dnnl.hpp:6995
memory::desc src_desc() const
Returns a source memory descriptor.
Definition dnnl.hpp:6986
inner_product_forward(const primitive_desc &pd)
Constructs an inner product forward propagation primitive.
Definition dnnl.hpp:7004
inner_product_forward()=default
Default constructor. Produces an empty object.
Descriptor for a layer normalization backward propagation primitive.
Definition dnnl.hpp:6674
desc(prop_kind prop_kind, const memory::desc &diff_data_desc, const memory::desc &data_desc, float epsilon, normalization_flags flags)
Constructs a descriptor for a layer normalization backward propagation primitive.
Definition dnnl.hpp:6746
desc(prop_kind prop_kind, const memory::desc &diff_data_desc, const memory::desc &data_desc, const memory::desc &stat_desc, float epsilon, normalization_flags flags)
Constructs a descriptor for a layer normalization backward propagation primitive.
Definition dnnl.hpp:6706
Primitive descriptor for a layer normalization backward propagation primitive.
Definition dnnl.hpp:6760
primitive_desc()=default
Default constructor. Produces an empty object.
memory::desc diff_src_desc() const
Returns a diff source memory descriptor.
Definition dnnl.hpp:6826
memory::desc mean_desc() const
Returns memory descriptor for mean.
Definition dnnl.hpp:6837
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &engine, const layer_normalization_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for a layer normalization backward propagation primitive.
Definition dnnl.hpp:6797
memory::desc workspace_desc() const
Returns the workspace memory descriptor.
Definition dnnl.hpp:6845
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for a layer normalization backward propagation primitive from a C API primitive descriptor that must have a matching kind.
Definition dnnl.hpp:6810
memory::desc diff_dst_desc() const
Returns a diff destination memory descriptor.
Definition dnnl.hpp:6829
memory::desc dst_desc() const
Returns a destination memory descriptor.
Definition dnnl.hpp:6823
memory::desc variance_desc() const
Returns memory descriptor for variance.
Definition dnnl.hpp:6840
primitive_desc(const desc &desc, const engine &engine, const layer_normalization_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for a layer normalization backward propagation primitive.
Definition dnnl.hpp:6777
memory::desc weights_desc() const
Returns a weights memory descriptor.
Definition dnnl.hpp:6820
memory::desc diff_weights_desc() const
Returns a diff weights memory descriptor.
Definition dnnl.hpp:6832
memory::desc src_desc() const
Returns a source memory descriptor.
Definition dnnl.hpp:6817
layer_normalization_backward(const primitive_desc &pd)
Constructs a layer normalization backward propagation primitive.
Definition dnnl.hpp:6854
layer_normalization_backward()=default
Default constructor. Produces an empty object.
Descriptor for a layer normalization forward propagation primitive.
Definition dnnl.hpp:6483
desc(prop_kind prop_kind, const memory::desc &data_desc, float epsilon, normalization_flags flags)
Constructs a descriptor for a layer normalization forward propagation primitive.
Definition dnnl.hpp:6564
desc(prop_kind prop_kind, const memory::desc &data_desc, const memory::desc &stat_desc, float epsilon, normalization_flags flags)
Constructs a descriptor for layer normalization forward propagation primitive.
Definition dnnl.hpp:6520
Primitive descriptor for a layer normalization forward propagation primitive.
Definition dnnl.hpp:6577
memory::desc dst_desc() const
Returns a destination memory descriptor.
Definition dnnl.hpp:6628
memory::desc src_desc() const
Returns a source memory descriptor.
Definition dnnl.hpp:6625
memory::desc workspace_desc() const
Returns the workspace memory descriptor.
Definition dnnl.hpp:6634
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for a layer normalization forward propagation primitive from a C AP...
Definition dnnl.hpp:6618
memory::desc variance_desc() const
Returns memory descriptor for variance.
Definition dnnl.hpp:6640
memory::desc weights_desc() const
Returns a weights memory descriptor.
Definition dnnl.hpp:6631
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &engine, bool allow_empty=false)
Constructs a primitive descriptor for a layer normalization forward propagation primitive.
Definition dnnl.hpp:6607
primitive_desc()=default
Default constructor. Produces an empty object.
primitive_desc(const desc &desc, const engine &engine, bool allow_empty=false)
Constructs a primitive descriptor for a layer normalization forward propagation primitive.
Definition dnnl.hpp:6591
memory::desc mean_desc() const
Returns memory descriptor for mean.
Definition dnnl.hpp:6637
layer_normalization_forward()=default
Default constructor. Produces an empty object.
layer_normalization_forward(const primitive_desc &pd)
Constructs a layer normalization forward propagation primitive.
Definition dnnl.hpp:6668
Descriptor for an LBR GRU backward propagation primitive.
Definition dnnl.hpp:9465
desc(prop_kind prop_kind, rnn_direction direction, const memory::desc &src_layer_desc, const memory::desc &src_iter_desc, const memory::desc &weights_layer_desc, const memory::desc &weights_iter_desc, const memory::desc &bias_desc, const memory::desc &dst_layer_desc, const memory::desc &dst_iter_desc, const memory::desc &diff_src_layer_desc, const memory::desc &diff_src_iter_desc, const memory::desc &diff_weights_layer_desc, const memory::desc &diff_weights_iter_desc, const memory::desc &diff_bias_desc, const memory::desc &diff_dst_layer_desc, const memory::desc &diff_dst_iter_desc, rnn_flags flags=rnn_flags::undef)
Constructs a descriptor for LBR GRU backward propagation primitive.
Definition dnnl.hpp:9535
Primitive descriptor for an LBR GRU backward propagation primitive.
Definition dnnl.hpp:9569
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &engine, const lbr_gru_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for an LBR GRU backward propagation primitive.
Definition dnnl.hpp:9606
memory::desc weights_layer_desc() const
Returns weights layer memory descriptor.
Definition dnnl.hpp:9632
memory::desc diff_weights_layer_desc() const
Returns diff weights layer memory descriptor.
Definition dnnl.hpp:9668
memory::desc diff_dst_iter_desc() const
Returns diff destination iteration memory descriptor.
Definition dnnl.hpp:9688
primitive_desc(const desc &desc, const engine &engine, const lbr_gru_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for an LBR GRU backward propagation primitive.
Definition dnnl.hpp:9586
memory::desc diff_bias_desc() const
Returns diff bias memory descriptor.
Definition dnnl.hpp:9678
primitive_desc()=default
Default constructor. Produces an empty object.
memory::desc dst_iter_desc() const
Returns destination iteration memory descriptor.
Definition dnnl.hpp:9650
memory::desc weights_iter_desc() const
Returns weights iteration memory descriptor.
Definition dnnl.hpp:9637
memory::desc src_iter_desc() const
Returns source iteration memory descriptor.
Definition dnnl.hpp:9629
memory::desc diff_src_iter_desc() const
Returns diff source iteration memory descriptor.
Definition dnnl.hpp:9663
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for an LBR GRU backward propagation primitive from a C API primitive...
Definition dnnl.hpp:9619
memory::desc workspace_desc() const
Returns the workspace memory descriptor.
Definition dnnl.hpp:9653
memory::desc bias_desc() const
Returns bias memory descriptor.
Definition dnnl.hpp:9642
memory::desc dst_layer_desc() const
Returns destination layer memory descriptor.
Definition dnnl.hpp:9645
memory::desc src_layer_desc() const
Returns source layer memory descriptor.
Definition dnnl.hpp:9624
memory::desc diff_weights_iter_desc() const
Returns diff weights iteration memory descriptor.
Definition dnnl.hpp:9673
memory::desc diff_dst_layer_desc() const
Returns diff destination layer memory descriptor.
Definition dnnl.hpp:9683
memory::desc diff_src_layer_desc() const
Returns diff source layer memory descriptor.
Definition dnnl.hpp:9658
lbr_gru_backward(const primitive_desc &pd)
Constructs an LBR GRU backward propagation primitive.
Definition dnnl.hpp:9699
lbr_gru_backward()=default
Default constructor. Produces an empty object.
Descriptor for an LBR GRU forward propagation primitive.
Definition dnnl.hpp:9301
desc(prop_kind prop_kind, rnn_direction direction, const memory::desc &src_layer_desc, const memory::desc &src_iter_desc, const memory::desc &weights_layer_desc, const memory::desc &weights_iter_desc, const memory::desc &bias_desc, const memory::desc &dst_layer_desc, const memory::desc &dst_iter_desc, rnn_flags flags=rnn_flags::undef)
Constructs a descriptor for LBR GRU forward propagation primitive.
Definition dnnl.hpp:9349
Primitive descriptor for an LBR GRU forward propagation primitive.
Definition dnnl.hpp:9372
memory::desc dst_iter_desc() const
Returns destination iteration memory descriptor.
Definition dnnl.hpp:9445
primitive_desc(const desc &desc, const engine &engine, bool allow_empty=false)
Constructs a primitive descriptor for an LBR GRU forward propagation primitive.
Definition dnnl.hpp:9386
memory::desc src_iter_desc() const
Returns source iteration memory descriptor.
Definition dnnl.hpp:9424
memory::desc dst_layer_desc() const
Returns destination layer memory descriptor.
Definition dnnl.hpp:9440
memory::desc workspace_desc() const
Returns the workspace memory descriptor.
Definition dnnl.hpp:9448
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for an LBR GRU forward propagation primitive from a C API primitive ...
Definition dnnl.hpp:9413
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &engine, bool allow_empty=false)
Constructs a primitive descriptor for an LBR GRU forward propagation primitive.
Definition dnnl.hpp:9402
memory::desc bias_desc() const
Returns bias memory descriptor.
Definition dnnl.hpp:9437
memory::desc src_layer_desc() const
Returns source layer memory descriptor.
Definition dnnl.hpp:9419
primitive_desc()=default
Default constructor. Produces an empty object.
memory::desc weights_iter_desc() const
Returns weights iteration memory descriptor.
Definition dnnl.hpp:9432
memory::desc weights_layer_desc() const
Returns weights layer memory descriptor.
Definition dnnl.hpp:9427
lbr_gru_forward()=default
Default constructor. Produces an empty object.
lbr_gru_forward(const primitive_desc &pd)
Constructs an LBR GRU forward propagation primitive.
Definition dnnl.hpp:9459
Descriptor for a logsoftmax backward propagation primitive.
Definition dnnl.hpp:6020
desc()=default
Default constructor. Produces an empty object.
desc(const memory::desc &diff_data_desc, const memory::desc &data_desc, int logsoftmax_axis)
Constructs a descriptor for a logsoftmax backward propagation primitive.
Definition dnnl.hpp:6040
Primitive descriptor for a logsoftmax backward propagation primitive.
Definition dnnl.hpp:6051
primitive_desc(const desc &desc, const engine &engine, const logsoftmax_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for a logsoftmax backward propagation primitive.
Definition dnnl.hpp:6068
memory::desc dst_desc() const
Returns a destination memory descriptor.
Definition dnnl.hpp:6110
primitive_desc()=default
Default constructor. Produces an empty object.
memory::desc diff_dst_desc() const
Returns a diff destination memory descriptor.
Definition dnnl.hpp:6116
memory::desc diff_src_desc() const
Returns a diff source memory descriptor.
Definition dnnl.hpp:6113
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for a logsoftmax backward propagation primitive from a C API primit...
Definition dnnl.hpp:6101
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &engine, const logsoftmax_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for a logsoftmax backward propagation primitive.
Definition dnnl.hpp:6088
logsoftmax_backward(const primitive_desc &pd)
Constructs a logsoftmax backward propagation primitive.
Definition dnnl.hpp:6125
logsoftmax_backward()=default
Default constructor. Produces an empty object.
Descriptor for a logsoftmax forward propagation primitive.
Definition dnnl.hpp:5920
desc()=default
Default constructor. Produces an empty object.
desc(prop_kind prop_kind, const memory::desc &data_desc, int logsoftmax_axis)
Constructs a descriptor for a logsoftmax forward propagation primitive.
Definition dnnl.hpp:5940
Primitive descriptor for a logsoftmax forward propagation primitive.
Definition dnnl.hpp:5951
memory::desc dst_desc() const
Returns a destination memory descriptor.
Definition dnnl.hpp:6005
memory::desc src_desc() const
Returns a source memory descriptor.
Definition dnnl.hpp:6002
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for a logsoftmax forward propagation primitive from a C API primiti...
Definition dnnl.hpp:5992
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &engine, bool allow_empty=false)
Constructs a primitive descriptor for a logsoftmax forward propagation primitive.
Definition dnnl.hpp:5981
primitive_desc()=default
Default constructor. Produces an empty object.
primitive_desc(const desc &desc, const engine &engine, bool allow_empty=false)
Constructs a primitive descriptor for a logsoftmax forward propagation primitive.
Definition dnnl.hpp:5965
logsoftmax_forward()=default
Default constructor. Produces an empty object.
logsoftmax_forward(const primitive_desc &pd)
Constructs a logsoftmax forward propagation primitive.
Definition dnnl.hpp:6014
Descriptor for an LRN backward propagation primitive.
Definition dnnl.hpp:5104
desc(algorithm algorithm, const memory::desc &data_desc, const memory::desc &diff_data_desc, memory::dim local_size, float alpha, float beta, float k=1.f)
Constructs a descriptor for an LRN backward propagation primitive.
Definition dnnl.hpp:5130
Primitive descriptor for an LRN backward propagation primitive.
Definition dnnl.hpp:5143
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &engine, const lrn_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for an LRN backward propagation primitive.
Definition dnnl.hpp:5178
primitive_desc(const desc &desc, const engine &engine, const lrn_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for an LRN backward propagation primitive.
Definition dnnl.hpp:5159
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for an LRN backward propagation primitive from a C API primitive de...
Definition dnnl.hpp:5191
memory::desc diff_dst_desc() const
Returns a diff destination memory descriptor.
Definition dnnl.hpp:5199
memory::desc workspace_desc() const
Returns the workspace memory descriptor.
Definition dnnl.hpp:5202
memory::desc diff_src_desc() const
Returns a diff source memory descriptor.
Definition dnnl.hpp:5196
primitive_desc()=default
Default constructor. Produces an empty object.
lrn_backward(const primitive_desc &pd)
Constructs an LRN backward propagation primitive.
Definition dnnl.hpp:5211
lrn_backward()=default
Default constructor. Produces an empty object.
Descriptor for an LRN forward propagation primitive.
Definition dnnl.hpp:4998
desc(prop_kind prop_kind, algorithm algorithm, const memory::desc &data_desc, memory::dim local_size, float alpha, float beta, float k=1.f)
Constructs a descriptor for an LRN forward propagation primitive.
Definition dnnl.hpp:5025
Primitive descriptor for an LRN forward propagation primitive.
Definition dnnl.hpp:5038
memory::desc src_desc() const
Returns a source memory descriptor.
Definition dnnl.hpp:5083
memory::desc dst_desc() const
Returns a destination memory descriptor.
Definition dnnl.hpp:5086
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &engine, bool allow_empty=false)
Constructs a primitive descriptor for an LRN forward propagation primitive.
Definition dnnl.hpp:5066
primitive_desc()=default
Default constructor. Produces an empty object.
primitive_desc(const desc &desc, const engine &engine, bool allow_empty=false)
Constructs a primitive descriptor for an LRN forward propagation primitive.
Definition dnnl.hpp:5051
memory::desc workspace_desc() const
Returns the workspace memory descriptor.
Definition dnnl.hpp:5089
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for an LRN forward propagation primitive from a C API primitive des...
Definition dnnl.hpp:5077
lrn_forward()=default
Default constructor. Produces an empty object.
lrn_forward(const primitive_desc &pd)
Constructs an LRN forward propagation primitive.
Definition dnnl.hpp:5098
Descriptor for an LSTM backward propagation primitive.
Definition dnnl.hpp:8283
desc(prop_kind prop_kind, rnn_direction direction, const memory::desc &src_layer_desc, const memory::desc &src_iter_desc, const memory::desc &src_iter_c_desc, const memory::desc &weights_layer_desc, const memory::desc &weights_iter_desc, const memory::desc &weights_peephole_desc, const memory::desc &bias_desc, const memory::desc &dst_layer_desc, const memory::desc &dst_iter_desc, const memory::desc &dst_iter_c_desc, const memory::desc &diff_src_layer_desc, const memory::desc &diff_src_iter_desc, const memory::desc &diff_src_iter_c_desc, const memory::desc &diff_weights_layer_desc, const memory::desc &diff_weights_iter_desc, const memory::desc &diff_weights_peephole_desc, const memory::desc &diff_bias_desc, const memory::desc &diff_dst_layer_desc, const memory::desc &diff_dst_iter_desc, const memory::desc &diff_dst_iter_c_desc, rnn_flags flags=rnn_flags::undef)
Constructs an LSTM (with or without peephole) descriptor for backward propagation using prop_kind,...
Definition dnnl.hpp:8561
desc(prop_kind prop_kind, rnn_direction direction, const memory::desc &src_layer_desc, const memory::desc &src_iter_desc, const memory::desc &src_iter_c_desc, const memory::desc &weights_layer_desc, const memory::desc &weights_iter_desc, const memory::desc &bias_desc, const memory::desc &dst_layer_desc, const memory::desc &dst_iter_desc, const memory::desc &dst_iter_c_desc, const memory::desc &diff_src_layer_desc, const memory::desc &diff_src_iter_desc, const memory::desc &diff_src_iter_c_desc, const memory::desc &diff_weights_layer_desc, const memory::desc &diff_weights_iter_desc, const memory::desc &diff_bias_desc, const memory::desc &diff_dst_layer_desc, const memory::desc &diff_dst_iter_desc, const memory::desc &diff_dst_iter_c_desc, rnn_flags flags=rnn_flags::undef)
Constructs an LSTM descriptor for backward propagation using prop_kind, direction,...
Definition dnnl.hpp:8687
desc(prop_kind prop_kind, rnn_direction direction, const memory::desc &src_layer_desc, const memory::desc &src_iter_desc, const memory::desc &src_iter_c_desc, const memory::desc &weights_layer_desc, const memory::desc &weights_iter_desc, const memory::desc &weights_peephole_desc, const memory::desc &weights_projection_desc, const memory::desc &bias_desc, const memory::desc &dst_layer_desc, const memory::desc &dst_iter_desc, const memory::desc &dst_iter_c_desc, const memory::desc &diff_src_layer_desc, const memory::desc &diff_src_iter_desc, const memory::desc &diff_src_iter_c_desc, const memory::desc &diff_weights_layer_desc, const memory::desc &diff_weights_iter_desc, const memory::desc &diff_weights_peephole_desc, const memory::desc &diff_weights_projection_desc, const memory::desc &diff_bias_desc, const memory::desc &diff_dst_layer_desc, const memory::desc &diff_dst_iter_desc, const memory::desc &diff_dst_iter_c_desc, rnn_flags flags=rnn_flags::undef)
Constructs an LSTM (with or without peephole and with or without projection) descriptor for backward ...
Definition dnnl.hpp:8413
Primitive descriptor for LSTM backward propagation.
Definition dnnl.hpp:8728
memory::desc weights_iter_desc() const
Returns weights iteration memory descriptor.
Definition dnnl.hpp:8799
memory::desc diff_dst_iter_desc() const
Returns diff destination iteration memory descriptor.
Definition dnnl.hpp:8880
memory::desc diff_weights_projection_desc() const
Returns diff weights projection memory descriptor.
Definition dnnl.hpp:8865
memory::desc weights_peephole_desc() const
Returns weights peephole memory descriptor.
Definition dnnl.hpp:8804
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for an LSTM backward propagation primitive from a C API primitive d...
Definition dnnl.hpp:8776
memory::desc diff_weights_peephole_desc() const
Returns diff weights peephole memory descriptor.
Definition dnnl.hpp:8860
memory::desc dst_iter_c_desc() const
Returns destination recurrent cell state memory descriptor.
Definition dnnl.hpp:8825
memory::desc src_layer_desc() const
Returns source layer memory descriptor.
Definition dnnl.hpp:8781
memory::desc dst_iter_desc() const
Returns destination iteration memory descriptor.
Definition dnnl.hpp:8822
primitive_desc()=default
Default constructor. Produces an empty object.
memory::desc diff_src_layer_desc() const
Returns diff source layer memory descriptor.
Definition dnnl.hpp:8835
memory::desc src_iter_desc() const
Returns source iteration memory descriptor.
Definition dnnl.hpp:8786
memory::desc diff_weights_iter_desc() const
Returns diff weights iteration memory descriptor.
Definition dnnl.hpp:8855
memory::desc weights_projection_desc() const
Returns weights projection memory descriptor.
Definition dnnl.hpp:8809
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &engine, const lstm_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for an LSTM backward propagation primitive.
Definition dnnl.hpp:8763
memory::desc diff_bias_desc() const
Returns diff bias memory descriptor.
Definition dnnl.hpp:8870
primitive_desc(const desc &desc, const engine &engine, const lstm_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for an LSTM backward propagation primitive.
Definition dnnl.hpp:8744
memory::desc bias_desc() const
Returns bias memory descriptor.
Definition dnnl.hpp:8814
memory::desc src_iter_c_desc() const
Returns source recurrent cell state memory descriptor.
Definition dnnl.hpp:8789
memory::desc dst_layer_desc() const
Returns destination layer memory descriptor.
Definition dnnl.hpp:8817
memory::desc diff_dst_iter_c_desc() const
Returns diff destination recurrent cell state memory descriptor.
Definition dnnl.hpp:8885
memory::desc diff_src_iter_desc() const
Returns diff source iteration memory descriptor.
Definition dnnl.hpp:8840
memory::desc diff_dst_layer_desc() const
Returns diff destination layer memory descriptor.
Definition dnnl.hpp:8875
memory::desc weights_layer_desc() const
Returns weights layer memory descriptor.
Definition dnnl.hpp:8794
memory::desc diff_weights_layer_desc() const
Returns diff weights layer memory descriptor.
Definition dnnl.hpp:8850
memory::desc workspace_desc() const
Returns the workspace memory descriptor.
Definition dnnl.hpp:8830
memory::desc diff_src_iter_c_desc() const
Returns diff source recurrent cell state memory descriptor.
Definition dnnl.hpp:8845
lstm_backward()=default
Default constructor. Produces an empty object.
lstm_backward(const primitive_desc &pd)
Constructs an LSTM backward propagation primitive.
Definition dnnl.hpp:8896
Descriptor for an LSTM forward propagation primitive.
Definition dnnl.hpp:7907
desc(prop_kind prop_kind, rnn_direction direction, const memory::desc &src_layer_desc, const memory::desc &src_iter_desc, const memory::desc &src_iter_c_desc, const memory::desc &weights_layer_desc, const memory::desc &weights_iter_desc, const memory::desc &weights_peephole_desc, const memory::desc &weights_projection_desc, const memory::desc &bias_desc, const memory::desc &dst_layer_desc, const memory::desc &dst_iter_desc, const memory::desc &dst_iter_c_desc, rnn_flags flags=rnn_flags::undef)
Constructs a descriptor for an LSTM (with or without peephole and with or without projection) forward...
Definition dnnl.hpp:7983
desc(prop_kind prop_kind, rnn_direction direction, const memory::desc &src_layer_desc, const memory::desc &src_iter_desc, const memory::desc &src_iter_c_desc, const memory::desc &weights_layer_desc, const memory::desc &weights_iter_desc, const memory::desc &bias_desc, const memory::desc &dst_layer_desc, const memory::desc &dst_iter_desc, const memory::desc &dst_iter_c_desc, rnn_flags flags=rnn_flags::undef)
Constructs a descriptor for an LSTM forward propagation primitive.
Definition dnnl.hpp:8146
desc(prop_kind prop_kind, rnn_direction direction, const memory::desc &src_layer_desc, const memory::desc &src_iter_desc, const memory::desc &src_iter_c_desc, const memory::desc &weights_layer_desc, const memory::desc &weights_iter_desc, const memory::desc &weights_peephole_desc, const memory::desc &bias_desc, const memory::desc &dst_layer_desc, const memory::desc &dst_iter_desc, const memory::desc &dst_iter_c_desc, rnn_flags flags=rnn_flags::undef)
Constructs a descriptor for an LSTM (with or without peephole) forward propagation primitive.
Definition dnnl.hpp:8070
Primitive descriptor for an LSTM forward propagation primitive.
Definition dnnl.hpp:8172
primitive_desc(const desc &desc, const engine &engine, bool allow_empty=false)
Constructs a primitive descriptor for an LSTM forward propagation primitive.
Definition dnnl.hpp:8185
memory::desc dst_iter_desc() const
Returns destination iteration memory descriptor.
Definition dnnl.hpp:8258
memory::desc weights_peephole_desc() const
Returns weights peephole memory descriptor.
Definition dnnl.hpp:8240
memory::desc weights_iter_desc() const
Returns weights iteration memory descriptor.
Definition dnnl.hpp:8235
memory::desc dst_layer_desc() const
Returns destination layer memory descriptor.
Definition dnnl.hpp:8253
primitive_desc()=default
Default constructor. Produces an empty object.
memory::desc workspace_desc() const
Returns the workspace memory descriptor.
Definition dnnl.hpp:8266
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for an LSTM forward propagation primitive from a C API primitive de...
Definition dnnl.hpp:8211
memory::desc dst_iter_c_desc() const
Returns destination recurrent cell state memory descriptor.
Definition dnnl.hpp:8261
memory::desc weights_layer_desc() const
Returns weights layer memory descriptor.
Definition dnnl.hpp:8230
memory::desc weights_projection_desc() const
Returns weights projection memory descriptor.
Definition dnnl.hpp:8245
memory::desc src_iter_c_desc() const
Returns source recurrent cell state memory descriptor.
Definition dnnl.hpp:8225
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &engine, bool allow_empty=false)
Constructs a primitive descriptor for an LSTM forward propagation primitive.
Definition dnnl.hpp:8200
memory::desc src_iter_desc() const
Returns source iteration memory descriptor.
Definition dnnl.hpp:8222
memory::desc bias_desc() const
Returns bias memory descriptor.
Definition dnnl.hpp:8250
memory::desc src_layer_desc() const
Returns source layer memory descriptor.
Definition dnnl.hpp:8217
lstm_forward(const primitive_desc &pd)
Constructs an LSTM forward propagation primitive.
Definition dnnl.hpp:8277
lstm_forward()=default
Default constructor. Produces an empty object.
Descriptor for a matmul primitive.
Definition dnnl.hpp:9994
desc(const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &dst_desc)
Constructs a descriptor for a matmul primitive.
Definition dnnl.hpp:10009
desc(const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &bias_desc, const memory::desc &dst_desc)
Constructs a descriptor for a matmul primitive.
Definition dnnl.hpp:10031
Primitive descriptor for a matmul primitive.
Definition dnnl.hpp:10041
memory::desc weights_desc() const
Returns a weights memory descriptor.
Definition dnnl.hpp:10083
primitive_desc(const desc &desc, const engine &engine, bool allow_empty=false)
Constructs a primitive descriptor for a matmul primitive.
Definition dnnl.hpp:10053
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for a matmul primitive from a C API primitive descriptor that must ...
Definition dnnl.hpp:10076
memory::desc bias_desc() const
Returns the bias memory descriptor.
Definition dnnl.hpp:10088
primitive_desc()=default
Default constructor. Produces an empty object.
memory::desc src_desc() const
Returns a source memory descriptor.
Definition dnnl.hpp:10080
memory::desc dst_desc() const
Returns a destination memory descriptor.
Definition dnnl.hpp:10093
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &engine, bool allow_empty=false)
Constructs a primitive descriptor for a matmul primitive.
Definition dnnl.hpp:10067
matmul(const primitive_desc &pd)
Constructs a matmul primitive.
Definition dnnl.hpp:10101
matmul()=default
Default constructor. Produces an empty object.
A memory descriptor.
Definition dnnl.hpp:1729
desc()
Constructs a zero (empty) memory descriptor.
Definition dnnl.hpp:1736
desc submemory_desc(const memory::dims &dims, const memory::dims &offsets, bool allow_empty=false) const
Constructs a memory descriptor for a region inside an area described by this memory descriptor.
Definition dnnl.hpp:1811
bool operator!=(const desc &other) const
An inequality operator.
Definition dnnl.hpp:1963
desc permute_axes(const std::vector< int > &permutation, bool allow_empty=false) const
Constructs a memory descriptor by permuting axes in an existing one.
Definition dnnl.hpp:1914
bool operator==(const desc &other) const
An equality operator.
Definition dnnl.hpp:1955
bool is_zero() const
Checks whether the memory descriptor is zero (empty).
Definition dnnl.hpp:1949
desc(const memory::dims &dims, data_type data_type, const memory::dims &strides, bool allow_empty=false)
Constructs a memory descriptor by strides.
Definition dnnl.hpp:1781
memory::dims dims() const
Returns dimensions of the memory descriptor.
Definition dnnl.hpp:1930
memory::data_type data_type() const
Returns the data type of the memory descriptor.
Definition dnnl.hpp:1936
size_t get_size() const
Returns size of the memory descriptor in bytes.
Definition dnnl.hpp:1944
desc reshape(const memory::dims &dims, bool allow_empty=false) const
Constructs a memory descriptor by reshaping an existing one.
Definition dnnl.hpp:1867
desc(const dnnl_memory_desc_t &data)
Constructs a memory descriptor from a C API data structure.
Definition dnnl.hpp:1798
desc(const memory::dims &dims, data_type data_type, format_tag format_tag, bool allow_empty=false)
Constructs a memory descriptor.
Definition dnnl.hpp:1753
dnnl_memory_desc_t data
The underlying C API data structure.
Definition dnnl.hpp:1732
Memory object.
Definition dnnl.hpp:1188
void unmap_data(void *mapped_ptr) const
Unmaps a memory object and writes back any changes made to the previously mapped memory buffer.
Definition dnnl.hpp:2118
cl_mem get_ocl_mem_object() const
Returns the OpenCL memory object associated with the memory.
Definition dnnl.hpp:2125
static void validate_dims(const std::vector< T > &v, int min_size=0)
Helper function that validates that an std::vector of dimensions can be safely converted to the C API...
Definition dnnl.hpp:1202
T * map_data() const
Maps a memory object and returns a host-side pointer to a memory buffer with a copy of its contents.
Definition dnnl.hpp:2102
dnnl_dim_t dim
Integer type for representing dimension sizes and indices.
Definition dnnl.hpp:1190
void set_ocl_mem_object(cl_mem mem_object)
Sets the OpenCL memory object mem_object associated with the memory.
Definition dnnl.hpp:2139
memory(const desc &md, const engine &engine, void *handle)
Constructs a memory object.
Definition dnnl.hpp:1992
format_tag
Memory format tag specification.
Definition dnnl.hpp:1282
@ ba
permuted 2D tensor
Definition dnnl.hpp:1295
@ bcda
permuted 4D tensor
Definition dnnl.hpp:1317
@ a
plain 1D tensor
Definition dnnl.hpp:1290
@ oihw
4D CNN weights tensor; an alias for dnnl::memory::format_tag::abcd
Definition dnnl.hpp:1382
@ wigo
4D CNN weights tensor with groups; an alias for dnnl::memory::format_tag::dcab
Definition dnnl.hpp:1403
@ ab
plain 2D tensor
Definition dnnl.hpp:1293
@ nc
2D CNN activations tensor; an alias for dnnl::memory::format_tag::ab
Definition dnnl.hpp:1347
@ goiw
4D CNN weights tensor with groups; an alias for dnnl::memory::format_tag::abcd
Definition dnnl.hpp:1401
@ ldio
4D LSTM projection tensor in the format (num_layers, num_directions, num_channels_in_hidden_state,...
Definition dnnl.hpp:1440
@ nt
2D RNN statistics tensor; an alias for dnnl::memory::format_tag::ba
Definition dnnl.hpp:1353
@ idhwo
5D CNN weights tensor; an alias for dnnl::memory::format_tag::bcdea
Definition dnnl.hpp:1398
@ cba
permuted 3D tensor
Definition dnnl.hpp:1306
@ oiw
3D CNN weights tensor; an alias for dnnl::memory::format_tag::abc
Definition dnnl.hpp:1374
@ goihw
5D CNN weights tensor with groups; an alias for dnnl::memory::format_tag::abcde
Definition dnnl.hpp:1405
@ aBcd8b
4D tensor blocked by 2nd dimension with block size 8
Definition dnnl.hpp:1488
@ ldgoi
5D RNN weights tensor in the format (num_layers, num_directions, num_gates, output_channels,...
Definition dnnl.hpp:1437
@ ldigo
5D RNN weights tensor in the format (num_layers, num_directions, input_channels, num_gates,...
Definition dnnl.hpp:1430
@ giodhw
6D CNN weights tensor with groups; an alias for dnnl::memory::format_tag::acbdef
Definition dnnl.hpp:1413
@ owi
3D CNN weights tensor; an alias for dnnl::memory::format_tag::acb
Definition dnnl.hpp:1376
@ ihwo
4D CNN weights tensor; an alias for dnnl::memory::format_tag::bcda
Definition dnnl.hpp:1388
@ bacd
permuted 4D tensor
Definition dnnl.hpp:1315
@ wio
3D CNN weights tensor; an alias for dnnl::memory::format_tag::cba
Definition dnnl.hpp:1378
@ abdc
permuted 4D tensor
Definition dnnl.hpp:1311
@ giohw
5D CNN weights tensor with groups; an alias for dnnl::memory::format_tag::acbde
Definition dnnl.hpp:1409
@ acb
permuted 3D tensor
Definition dnnl.hpp:1300
@ acbdef
permuted 6D tensor
Definition dnnl.hpp:1340
@ ntc
3D RNN data tensor in the format (batch, seq_length, input channels).
Definition dnnl.hpp:1420
@ bcdea
permuted 5D tensor
Definition dnnl.hpp:1332
@ nhwc
4D CNN activations tensor; an alias for dnnl::memory::format_tag::acdb
Definition dnnl.hpp:1361
@ acdb
permuted 4D tensor
Definition dnnl.hpp:1313
@ ldoi
4D LSTM projection tensor in the format (num_layers, num_directions, num_channels_in_recurrent_projec...
Definition dnnl.hpp:1443
@ bac
permuted 3D tensor
Definition dnnl.hpp:1302
@ abdec
permuted 5D tensor
Definition dnnl.hpp:1326
@ iwo
3D CNN weights tensor; an alias for dnnl::memory::format_tag::bca
Definition dnnl.hpp:1380
@ cn
2D CNN activations tensor; an alias for dnnl::memory::format_tag::ba
Definition dnnl.hpp:1349
@ abc
plain 3D tensor
Definition dnnl.hpp:1298
@ dcab
permuted 4D tensor
Definition dnnl.hpp:1321
@ oidhw
5D CNN weights tensor; an alias for dnnl::memory::format_tag::abcde
Definition dnnl.hpp:1392
@ x
1D tensor; an alias for dnnl::memory::format_tag::a
Definition dnnl.hpp:1345
@ oi
2D CNN weights tensor; an alias for dnnl::memory::format_tag::ab
Definition dnnl.hpp:1370
@ goidhw
6D CNN weights tensor with groups; an alias for dnnl::memory::format_tag::abcdef
Definition dnnl.hpp:1411
@ tn
2D RNN statistics tensor; an alias for dnnl::memory::format_tag::ab
Definition dnnl.hpp:1351
@ abcde
plain 5D tensor
Definition dnnl.hpp:1324
@ dhwigo
6D CNN weights tensor with groups; an alias for dnnl::memory::format_tag::defcab
Definition dnnl.hpp:1415
@ ldnc
4D RNN states tensor in the format (num_layers, num_directions, batch, state channels).
Definition dnnl.hpp:1423
@ bca
permuted 3D tensor
Definition dnnl.hpp:1304
@ ldgo
4D RNN bias tensor in the format (num_layers, num_directions, num_gates, output_channels).
Definition dnnl.hpp:1450
@ ohwi
4D CNN weights tensor; an alias for dnnl::memory::format_tag::acdb
Definition dnnl.hpp:1386
@ decab
permuted 5D tensor
Definition dnnl.hpp:1336
@ ncw
3D CNN activations tensor; an alias for dnnl::memory::format_tag::abc
Definition dnnl.hpp:1355
@ ABcd8b8a
4D tensor blocked by 1st and 2nd dimension with block size 8
Definition dnnl.hpp:1492
@ odhwi
5D CNN weights tensor; an alias for dnnl::memory::format_tag::acdeb
Definition dnnl.hpp:1396
@ iohw
4D CNN weights tensor; an alias for dnnl::memory::format_tag::bacd
Definition dnnl.hpp:1390
@ tnc
3D RNN data tensor in the format (seq_length, batch, input channels).
Definition dnnl.hpp:1418
@ defcab
permuted 6D tensor
Definition dnnl.hpp:1342
@ nwc
3D CNN activations tensor; an alias for dnnl::memory::format_tag::acb
Definition dnnl.hpp:1357
@ ndhwc
5D CNN activations tensor; an alias for dnnl::memory::format_tag::acdeb
Definition dnnl.hpp:1367
@ hwio
4D CNN weights tensor; an alias for dnnl::memory::format_tag::cdba
Definition dnnl.hpp:1384
@ nchw
4D CNN activations tensor; an alias for dnnl::memory::format_tag::abcd
Definition dnnl.hpp:1359
@ acbde
permuted 5D tensor
Definition dnnl.hpp:1328
@ abcd
plain 4D tensor
Definition dnnl.hpp:1309
@ ncdhw
5D CNN activations tensor; an alias for dnnl::memory::format_tag::abcde
Definition dnnl.hpp:1365
@ abcdef
plain 6D tensor
Definition dnnl.hpp:1338
@ dhwio
5D CNN weights tensor; an alias for dnnl::memory::format_tag::cdeba
Definition dnnl.hpp:1394
@ acdeb
permuted 5D tensor
Definition dnnl.hpp:1330
@ io
2D CNN weights tensor; an alias for dnnl::memory::format_tag::ba
Definition dnnl.hpp:1372
@ cdeba
permuted 5D tensor
Definition dnnl.hpp:1334
@ chwn
4D CNN activations tensor; an alias for dnnl::memory::format_tag::bcda
Definition dnnl.hpp:1363
@ hwigo
5D CNN weights tensor with groups; an alias for dnnl::memory::format_tag::decab
Definition dnnl.hpp:1407
@ cdba
permuted 4D tensor
Definition dnnl.hpp:1319
data_type
Data type specification.
Definition dnnl.hpp:1208
@ u8
8-bit unsigned integer.
Definition dnnl.hpp:1222
@ s8
8-bit signed integer.
Definition dnnl.hpp:1220
@ f32
32-bit/single-precision floating point.
Definition dnnl.hpp:1216
@ f16
16-bit/half-precision floating point.
Definition dnnl.hpp:1212
@ s32
32-bit signed integer.
Definition dnnl.hpp:1218
@ undef
Undefined data type (used for empty memory descriptors).
Definition dnnl.hpp:1210
@ bf16
non-standard 16-bit floating point with 7-bit mantissa.
Definition dnnl.hpp:1214
engine get_engine() const
Returns the associated engine.
Definition dnnl.hpp:2018
void set_data_handle(void *handle, const stream &stream) const
Sets data handle.
Definition dnnl.hpp:2061
memory(const desc &md, const engine &engine)
Constructs a memory object.
Definition dnnl.hpp:2006
format_kind
Memory format kind.
Definition dnnl.hpp:1226
@ any
Unspecified format kind.
Definition dnnl.hpp:1231
@ blocked
A tensor in a generic format described by the stride and blocking values in each dimension.
Definition dnnl.hpp:1235
@ wino
Weights format used in 8bit Winograd convolution.
Definition dnnl.hpp:1237
@ packed
Packed weights format used in RNN.
Definition dnnl.hpp:1239
void set_data_handle(void *handle) const
Sets data handle.
Definition dnnl.hpp:2075
desc get_desc() const
Returns the associated memory descriptor.
Definition dnnl.hpp:2010
void * get_data_handle() const
Returns the underlying memory buffer.
Definition dnnl.hpp:2028
std::vector< dim > dims
Vector of dimensions.
Definition dnnl.hpp:1193
Descriptor for a pooling backward propagation primitive.
Definition dnnl.hpp:5344
desc(algorithm algorithm, const memory::desc &diff_src_desc, const memory::desc &diff_dst_desc, const memory::dims &strides, const memory::dims &kernel, const memory::dims &padding_l, const memory::dims &padding_r)
Constructs a descriptor for a pooling backward propagation primitive.
Definition dnnl.hpp:5372
Primitive descriptor for a pooling backward propagation primitive.
Definition dnnl.hpp:5391
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &engine, const pooling_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for a pooling backward propagation primitive.
Definition dnnl.hpp:5426
primitive_desc(const desc &desc, const engine &engine, const pooling_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for a pooling backward propagation primitive.
Definition dnnl.hpp:5407
memory::desc diff_dst_desc() const
Returns a diff destination memory descriptor.
Definition dnnl.hpp:5447
memory::desc workspace_desc() const
Returns the workspace memory descriptor.
Definition dnnl.hpp:5450
primitive_desc()=default
Default constructor. Produces an empty object.
memory::desc diff_src_desc() const
Returns a diff source memory descriptor.
Definition dnnl.hpp:5444
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for a pooling backward propagation primitive from a C API primitive...
Definition dnnl.hpp:5439
pooling_backward()=default
Default constructor. Produces an empty object.
pooling_backward(const primitive_desc &pd)
Constructs a pooling backward propagation primitive.
Definition dnnl.hpp:5459
Descriptor for a pooling forward propagation primitive.
Definition dnnl.hpp:5227
desc(prop_kind prop_kind, algorithm algorithm, const memory::desc &src_desc, const memory::desc &dst_desc, const memory::dims &strides, const memory::dims &kernel, const memory::dims &padding_l, const memory::dims &padding_r)
Constructs a descriptor for a pooling forward propagation primitive.
Definition dnnl.hpp:5259
Primitive descriptor for a pooling forward propagation primitive.
Definition dnnl.hpp:5278
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &engine, bool allow_empty=false)
Constructs a primitive descriptor for a pooling forward propagation primitive.
Definition dnnl.hpp:5306
memory::desc dst_desc() const
Returns a destination memory descriptor.
Definition dnnl.hpp:5326
primitive_desc()=default
Default constructor. Produces an empty object.
memory::desc src_desc() const
Returns a source memory descriptor.
Definition dnnl.hpp:5323
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for a pooling forward propagation primitive from a C API primitive ...
Definition dnnl.hpp:5317
primitive_desc(const desc &desc, const engine &engine, bool allow_empty=false)
Constructs a primitive descriptor for a pooling forward propagation primitive.
Definition dnnl.hpp:5291
memory::desc workspace_desc() const
Returns the workspace memory descriptor.
Definition dnnl.hpp:5329
pooling_forward(const primitive_desc &pd)
Constructs a pooling forward propagation primitive.
Definition dnnl.hpp:5338
pooling_forward()=default
Default constructor. Produces an empty object.
Post-ops.
Definition dnnl.hpp:2205
void append_sum(float scale=1.)
Appends an accumulation (sum) post-op.
Definition dnnl.hpp:2251
void get_params_dw_k3s1p1(int index, memory::data_type &weights_data_type, memory::data_type &bias_data_type, memory::data_type &dst_data_type, int &mask, std::vector< float > &scales) const
Returns the parameters of a depthwise post-op with stride 1.
Definition dnnl.hpp:2357
void append_dw_k3s1p1(memory::data_type weights_data_type, memory::data_type bias_data_type, memory::data_type dst_data_type, int mask, const std::vector< float > &scales)
Appends a depthwise post-op convolution with stride 1.
Definition dnnl.hpp:2331
primitive::kind kind(int index) const
Returns the primitive kind of post-op at entry with a certain index.
Definition dnnl.hpp:2222
void append_eltwise(float scale, algorithm algorithm, float alpha, float beta)
Appends an elementwise post-op.
Definition dnnl.hpp:2280
int len() const
Returns the number of post-ops entries.
Definition dnnl.hpp:2217
void append_dw_k3s2p1(memory::data_type weights_data_type, memory::data_type bias_data_type, memory::data_type dst_data_type, int mask, const std::vector< float > &scales)
Appends a depthwise post-op convolution with stride 2.
Definition dnnl.hpp:2416
post_ops()
Constructs an empty sequence of post-ops.
Definition dnnl.hpp:2209
void get_params_eltwise(int index, float &scale, algorithm &algorithm, float &alpha, float &beta) const
Returns parameters of an elementwise post-op.
Definition dnnl.hpp:2294
void get_params_dw_k3s2p1(int index, memory::data_type &weights_data_type, memory::data_type &bias_data_type, memory::data_type &dst_data_type, int &mask, std::vector< float > &scales) const
Returns the parameters of a depthwise post-op with stride 2.
Definition dnnl.hpp:2442
void get_params_sum(int index, float &scale) const
Returns the parameters of an accumulation (sum) post-op.
Definition dnnl.hpp:2260
Primitive attributes.
Definition dnnl.hpp:2481
void get_zero_points(int arg, int &mask, std::vector< int32_t > &zero_points) const
Returns zero points correspondence mask and values.
Definition dnnl.hpp:2648
const post_ops get_post_ops() const
Returns post-ops previously set via set_post_ops().
Definition dnnl.hpp:2694
void set_rnn_data_qparams(float scale, float shift)
Sets quantization scale and shift parameters for RNN data tensors.
Definition dnnl.hpp:2749
void set_output_scales(int mask, const std::vector< float > &scales)
Sets output scaling factors correspondence mask and values.
Definition dnnl.hpp:2583
void set_rnn_weights_qparams(int mask, const std::vector< float > &scales)
Sets quantization scaling factors for RNN weights tensors.
Definition dnnl.hpp:2782
void set_scratchpad_mode(scratchpad_mode mode)
Sets scratchpad mode.
Definition dnnl.hpp:2512
void set_scales(int arg, int mask, const std::vector< float > &scales)
Sets scaling factors for primitive operations for a given memory argument.
Definition dnnl.hpp:2631
void get_scales(int arg, int &mask, std::vector< float > &scales) const
Returns scaling factors correspondence mask and values for a given memory argument.
Definition dnnl.hpp:2601
void get_output_scales(int &mask, std::vector< float > &scales) const
Returns output scaling factors correspondence mask and values.
Definition dnnl.hpp:2527
primitive_attr(dnnl_primitive_attr_t attr)
Creates primitive attributes from a C API dnnl_primitive_attr_t handle.
Definition dnnl.hpp:2497
void set_post_ops(const post_ops ops)
Sets post-ops.
Definition dnnl.hpp:2711
primitive_attr()
Constructs default (empty) primitive attributes.
Definition dnnl.hpp:2485
void set_zero_points(int arg, int mask, const std::vector< int32_t > &zero_points)
Sets zero points for primitive operations for a given memory argument.
Definition dnnl.hpp:2683
scratchpad_mode get_scratchpad_mode() const
Returns the scratchpad mode.
Definition dnnl.hpp:2501
Base class for all primitive descriptors.
Definition dnnl.hpp:2796
primitive_attr get_primitive_attr() const
Returns the primitive attributes.
Definition dnnl.hpp:2980
memory::desc diff_weights_desc(int idx) const
Returns a diff weights memory descriptor.
Definition dnnl.hpp:2906
primitive_desc_base()=default
Default constructor. Produces an empty object.
engine get_engine() const
Returns the engine of the primitive descriptor.
Definition dnnl.hpp:2804
memory::desc query_md(query what, int idx=0) const
Returns a memory descriptor.
Definition dnnl.hpp:2841
memory::desc dst_desc(int idx) const
Returns a destination memory descriptor.
Definition dnnl.hpp:2870
memory::desc diff_dst_desc(int idx) const
Returns a diff destination memory descriptor.
Definition dnnl.hpp:2897
memory::desc scratchpad_desc() const
Returns the scratchpad memory descriptor.
Definition dnnl.hpp:2962
void reset_with_clone(const_dnnl_primitive_desc_t pd)
Resets the value of the handle to a clone of a C API primitive descriptor.
Definition dnnl.hpp:3004
dnnl::primitive::kind get_kind() const
Returns the kind of the primitive descriptor.
Definition dnnl.hpp:2992
memory::desc diff_dst_desc() const
Returns a diff destination memory descriptor.
Definition dnnl.hpp:2941
memory::desc diff_src_desc(int idx) const
Returns a diff source memory descriptor.
Definition dnnl.hpp:2888
primitive_desc_base(dnnl_primitive_desc_t pd, dnnl::primitive::kind prim_kind, dnnl::prop_kind prop_kind)
Constructs a primitive descriptor base object from a clone of a C API primitive descriptor after veri...
Definition dnnl.hpp:3039
memory::desc weights_desc() const
Returns a weights memory descriptor.
Definition dnnl.hpp:2929
const char * impl_info_str() const
Returns implementation name.
Definition dnnl.hpp:2808
primitive_desc_base(dnnl_primitive_desc_t pd, dnnl::primitive::kind prim_kind, dnnl::prop_kind prop_kind1, dnnl::prop_kind prop_kind2)
Constructs a primitive descriptor base object from a clone of a C API primitive descriptor after veri...
Definition dnnl.hpp:3056
primitive_desc_base(dnnl_primitive_desc_t pd, dnnl::primitive::kind prim_kind)
Constructs a primitive descriptor base object from a clone of a C API primitive descriptor after veri...
Definition dnnl.hpp:3024
memory::desc diff_src_desc() const
Returns a diff source memory descriptor.
Definition dnnl.hpp:2935
memory::desc weights_desc(int idx) const
Returns a weights memory descriptor.
Definition dnnl.hpp:2879
memory::dim query_s64(query what) const
Returns a memory::dim value (same as int64_t).
Definition dnnl.hpp:2820
memory::desc workspace_desc() const
Returns the workspace memory descriptor.
Definition dnnl.hpp:2953
engine scratchpad_engine() const
Returns the engine on which the scratchpad memory is located.
Definition dnnl.hpp:2968
memory::desc dst_desc() const
Returns a destination memory descriptor.
Definition dnnl.hpp:2923
memory::desc src_desc(int idx) const
Returns a source memory descriptor.
Definition dnnl.hpp:2861
memory::desc src_desc() const
Returns a source memory descriptor.
Definition dnnl.hpp:2917
memory::desc diff_weights_desc() const
Returns a diff weights memory descriptor.
Definition dnnl.hpp:2947
A base class for descriptors of all primitives that have an operation descriptor and that support ite...
Definition dnnl.hpp:3458
primitive_desc(const_dnnl_op_desc_t desc, const primitive_attr *attr, const engine &engine, const_dnnl_primitive_desc_t hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor.
Definition dnnl.hpp:3485
primitive_desc_base()=default
Default constructor. Produces an empty object.
bool next_impl()
Advances the primitive iterator to the next implementation.
Definition dnnl.hpp:3503
Base class for all computational primitives.
Definition dnnl.hpp:277
primitive()=default
Default constructor. Constructs an empty object.
handle()=default
Constructs an empty handle object.
primitive(const primitive_desc &pd)
Constructs a primitive from a primitive descriptor.
void execute(const stream &stream, const std::unordered_map< int, memory > &args) const
Executes computations specified by the primitive in a specified stream.
kind
Kinds of primitives supported by the library.
Definition dnnl.hpp:282
@ matmul
A matmul (matrix multiplication) primitive.
Definition dnnl.hpp:318
@ sum
A summation primitive.
Definition dnnl.hpp:292
@ deconvolution
A deconvolution primitive.
Definition dnnl.hpp:296
@ inner_product
An inner product primitive.
Definition dnnl.hpp:310
@ logsoftmax
A logsoftmax primitive.
Definition dnnl.hpp:316
@ layer_normalization
A layer normalization primitive.
Definition dnnl.hpp:308
@ concat
A (out-of-place) tensor concatenation primitive.
Definition dnnl.hpp:290
@ pooling
A pooling primitive.
Definition dnnl.hpp:302
@ resampling
A resampling primitive.
Definition dnnl.hpp:320
@ shuffle
A shuffle primitive.
Definition dnnl.hpp:288
@ rnn
An RNN primitive.
Definition dnnl.hpp:312
@ batch_normalization
A batch normalization primitive.
Definition dnnl.hpp:306
@ lrn
An LRN primitive.
Definition dnnl.hpp:304
@ reorder
A reorder primitive.
Definition dnnl.hpp:286
@ eltwise
An element-wise primitive.
Definition dnnl.hpp:298
@ binary
A binary primitive.
Definition dnnl.hpp:314
@ convolution
A convolution primitive.
Definition dnnl.hpp:294
@ softmax
A softmax primitive.
Definition dnnl.hpp:300
@ undef
Undefined primitive.
Definition dnnl.hpp:284
primitive(const_dnnl_primitive_desc_t c_pd)
Constructs a primitive from a C API primitive descriptor.
Primitive descriptor for a reorder primitive.
Definition dnnl.hpp:3120
memory::desc src_desc() const
Returns a source memory descriptor.
Definition dnnl.hpp:3195
primitive_desc_base()=default
Default constructor. Produces an empty object.
engine get_src_engine() const
Returns the engine on which the source memory is allocated.
Definition dnnl.hpp:3184
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for a reorder primitive from a C API primitive descriptor which must ...
Definition dnnl.hpp:3179
engine get_dst_engine() const
Returns the engine on which the destination memory is allocated.
Definition dnnl.hpp:3190
primitive_desc(const memory &src, const memory &dst, const primitive_attr &attr=primitive_attr())
Constructs a primitive descriptor for a reorder primitive.
Definition dnnl.hpp:3161
primitive_desc()=default
Default constructor. Produces an empty object.
memory::desc dst_desc() const
Returns a destination memory descriptor.
Definition dnnl.hpp:3198
primitive_desc(const engine &src_engine, const memory::desc &src_md, const engine &dst_engine, const memory::desc &dst_md, const primitive_attr &attr=primitive_attr())
Constructs a primitive descriptor for a reorder primitive.
Definition dnnl.hpp:3141
void execute(const stream &stream, memory &src, memory &dst) const
Executes the reorder primitive.
Definition dnnl.hpp:3227
reorder(const primitive_desc &pd)
Constructs a reorder primitive.
Definition dnnl.hpp:3206
reorder()=default
Default constructor. Produces an empty object.
reorder(const memory &src, const memory &dst, const primitive_attr &attr=primitive_attr())
Constructs a reorder primitive that would reorder data between memory objects having the same memory ...
Definition dnnl.hpp:3215
Descriptor for a resampling backward propagation primitive.
Definition dnnl.hpp:10278
desc(algorithm algorithm, const memory::desc &diff_src_desc, const memory::desc &diff_dst_desc)
Constructs a descriptor for a resampling backward propagation primitive using source and destination ...
Definition dnnl.hpp:10295
desc(algorithm algorithm, const std::vector< float > &factors, const memory::desc &diff_src_desc, const memory::desc &diff_dst_desc)
Constructs a descriptor for a resampling backward propagation primitive.
Definition dnnl.hpp:10318
Primitive descriptor for a resampling backward propagation primitive.
Definition dnnl.hpp:10331
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for a resampling backward propagation primitive from a C API primit...
Definition dnnl.hpp:10381
primitive_desc(const desc &desc, const engine &engine, const resampling_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for a resampling backward propagation primitive.
Definition dnnl.hpp:10348
memory::desc diff_src_desc() const
Returns a diff source memory descriptor.
Definition dnnl.hpp:10386
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &engine, const resampling_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for a resampling backward propagation primitive.
Definition dnnl.hpp:10368
primitive_desc()=default
Default constructor. Produces an empty object.
memory::desc diff_dst_desc() const
Returns a diff destination memory descriptor.
Definition dnnl.hpp:10389
resampling_backward(const primitive_desc &pd)
Constructs a resampling backward propagation primitive.
Definition dnnl.hpp:10398
resampling_backward()=default
Default constructor. Produces an empty object.
Descriptor for a resampling forward propagation primitive.
Definition dnnl.hpp:10119
desc(prop_kind prop_kind, algorithm algorithm, const memory::desc &src_desc, const memory::desc &dst_desc)
Constructs a descriptor for a resampling forward propagation primitive using source and destination m...
Definition dnnl.hpp:10143
desc(prop_kind prop_kind, algorithm algorithm, const std::vector< float > &factors, const memory::desc &src_desc)
Constructs a descriptor for a resampling forward propagation primitive using source memory descriptor...
Definition dnnl.hpp:10166
desc(prop_kind prop_kind, algorithm algorithm, const std::vector< float > &factors, const memory::desc &src_desc, const memory::desc &dst_desc)
Constructs a descriptor for a resampling forward propagation primitive.
Definition dnnl.hpp:10199
Primitive descriptor for a resampling forward propagation primitive.
Definition dnnl.hpp:10213
memory::desc dst_desc() const
Returns a destination memory descriptor.
Definition dnnl.hpp:10263
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &engine, bool allow_empty=false)
Constructs a primitive descriptor for a resampling forward propagation primitive.
Definition dnnl.hpp:10243
memory::desc src_desc() const
Returns a source memory descriptor.
Definition dnnl.hpp:10260
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for a resampling forward propagation primitive from a C API primiti...
Definition dnnl.hpp:10254
primitive_desc(const desc &desc, const engine &engine, bool allow_empty=false)
Constructs a primitive descriptor for a resampling forward propagation primitive.
Definition dnnl.hpp:10227
primitive_desc()=default
Default constructor. Produces an empty object.
resampling_forward()=default
Default constructor. Produces an empty object.
resampling_forward(const primitive_desc &pd)
Constructs a resampling forward propagation primitive.
Definition dnnl.hpp:10272
Base class for primitive descriptors for RNN primitives.
Definition dnnl.hpp:7274
rnn_primitive_desc_base(dnnl_primitive_desc_t pd, dnnl::prop_kind prop_kind, dnnl::algorithm cell_kind)
Constructs an RNN primitive descriptor base from a C API primitive descriptor while checking that it ...
Definition dnnl.hpp:7287
memory::desc dst_iter_c_desc() const
Returns destination recurrent cell state memory descriptor.
Definition dnnl.hpp:7359
memory::desc weights_peephole_desc() const
Returns weights peephole memory descriptor.
Definition dnnl.hpp:7325
memory::desc diff_weights_layer_desc() const
Returns diff weights layer memory descriptor.
Definition dnnl.hpp:7385
memory::desc weights_layer_desc() const
Returns weights layer memory descriptor.
Definition dnnl.hpp:7313
memory::desc weights_iter_desc() const
Returns weights iteration memory descriptor.
Definition dnnl.hpp:7319
memory::desc diff_src_iter_desc() const
Returns diff source iteration memory descriptor.
Definition dnnl.hpp:7373
memory::desc diff_dst_iter_c_desc() const
Returns diff destination recurrent cell state memory descriptor.
Definition dnnl.hpp:7433
memory::desc diff_weights_iter_desc() const
Returns diff weights iteration memory descriptor.
Definition dnnl.hpp:7391
memory::desc diff_dst_iter_desc() const
Returns diff destination iteration memory descriptor.
Definition dnnl.hpp:7427
rnn_primitive_desc_base()=default
Default constructor. Produces an empty object.
memory::desc diff_src_iter_c_desc() const
Returns diff source recurrent cell state memory descriptor.
Definition dnnl.hpp:7379
memory::desc diff_bias_desc() const
Returns diff bias memory descriptor.
Definition dnnl.hpp:7413
memory::desc dst_layer_desc() const
Returns destination layer memory descriptor.
Definition dnnl.hpp:7345
memory::desc diff_weights_projection_desc() const
Returns diff weights projection memory descriptor.
Definition dnnl.hpp:7404
memory::desc src_iter_c_desc() const
Returns source recurrent cell state memory descriptor.
Definition dnnl.hpp:7307
memory::desc src_iter_desc() const
Returns source iteration memory descriptor.
Definition dnnl.hpp:7301
memory::desc bias_desc() const
Returns bias memory descriptor.
Definition dnnl.hpp:7339
memory::desc weights_projection_desc() const
Returns weights projection memory descriptor.
Definition dnnl.hpp:7331
memory::desc src_layer_desc() const
Returns source layer memory descriptor.
Definition dnnl.hpp:7293
memory::desc diff_dst_layer_desc() const
Returns diff destination layer memory descriptor.
Definition dnnl.hpp:7419
memory::desc dst_iter_desc() const
Returns destination iteration memory descriptor.
Definition dnnl.hpp:7353
memory::desc diff_weights_peephole_desc() const
Returns diff weights peephole memory descriptor.
Definition dnnl.hpp:7397
memory::desc diff_src_layer_desc() const
Returns diff source layer memory descriptor.
Definition dnnl.hpp:7365
Descriptor for a shuffle backward propagation primitive.
Definition dnnl.hpp:9796
desc(const memory::desc &diff_data_desc, int axis, int group_size)
Constructs a descriptor for a shuffle backward propagation primitive.
Definition dnnl.hpp:9812
Primitive descriptor for a shuffle backward propagation primitive.
Definition dnnl.hpp:9821
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for a shuffle backward propagation primitive from a C API primitive...
Definition dnnl.hpp:9852
memory::desc diff_src_desc() const
Returns a diff source memory descriptor.
Definition dnnl.hpp:9857
primitive_desc()=default
Default constructor. Produces an empty object.
primitive_desc(const desc &desc, const engine &engine, const shuffle_forward::primitive_desc &hint_fwd_pd, const primitive_attr &attr=primitive_attr(), bool allow_empty=false)
Constructs a primitive descriptor for a shuffle backward propagation primitive.
Definition dnnl.hpp:9839
memory::desc diff_dst_desc() const
Returns a diff destination memory descriptor.
Definition dnnl.hpp:9860
shuffle_backward()=default
Default constructor. Produces an empty object.
shuffle_backward(const primitive_desc &pd)
Constructs a shuffle backward propagation primitive.
Definition dnnl.hpp:9869
Descriptor for a shuffle forward propagation primitive.
Definition dnnl.hpp:9715
desc(prop_kind prop_kind, const memory::desc &data_desc, int axis, int group_size)
Constructs a descriptor for a shuffle forward propagation primitive.
Definition dnnl.hpp:9733
Primitive descriptor for a shuffle forward propagation primitive.
Definition dnnl.hpp:9744
memory::desc dst_desc() const
Returns a destination memory descriptor.
Definition dnnl.hpp:9780
memory::desc src_desc() const
Returns a source memory descriptor.
Definition dnnl.hpp:9777
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for a shuffle forward propagation primitive from a C API primitive ...
Definition dnnl.hpp:9771
primitive_desc(const desc &desc, const engine &engine, const primitive_attr &attr=primitive_attr(), bool allow_empty=false)
Constructs a primitive descriptor for a shuffle forward propagation primitive.
Definition dnnl.hpp:9759
primitive_desc()=default
Default constructor. Produces an empty object.
shuffle_forward()=default
Default constructor. Produces an empty object.
shuffle_forward(const primitive_desc &pd)
Constructs a shuffle forward propagation primitive.
Definition dnnl.hpp:9789
Descriptor for a softmax backward propagation primitive.
Definition dnnl.hpp:5803
desc(const memory::desc &diff_data_desc, const memory::desc &data_desc, int softmax_axis)
Constructs a descriptor for a softmax backward propagation primitive.
Definition dnnl.hpp:5823
desc()=default
Default constructor. Produces an empty object.
Primitive descriptor for a softmax backward propagation primitive.
Definition dnnl.hpp:5834
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for a softmax backward propagation primitive from a C API primitive...
Definition dnnl.hpp:5884
memory::desc diff_dst_desc() const
Returns a diff destination memory descriptor.
Definition dnnl.hpp:5895
memory::desc diff_src_desc() const
Returns a diff source memory descriptor.
Definition dnnl.hpp:5892
primitive_desc()=default
Default constructor. Produces an empty object.
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &engine, const softmax_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for a softmax backward propagation primitive.
Definition dnnl.hpp:5871
primitive_desc(const desc &desc, const engine &engine, const softmax_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for a softmax backward propagation primitive.
Definition dnnl.hpp:5851
memory::desc dst_desc() const
Returns a destination memory descriptor.
Definition dnnl.hpp:5889
softmax_backward()=default
Default constructor. Produces an empty object.
softmax_backward(const primitive_desc &pd)
Constructs a softmax backward propagation primitive.
Definition dnnl.hpp:5904
Descriptor for a softmax forward propagation primitive.
Definition dnnl.hpp:5707
desc(prop_kind prop_kind, const memory::desc &data_desc, int softmax_axis)
Constructs a descriptor for a softmax forward propagation primitive.
Definition dnnl.hpp:5727
desc()=default
Default constructor. Produces an empty object.
Primitive descriptor for a softmax forward propagation primitive.
Definition dnnl.hpp:5738
primitive_desc(const desc &desc, const engine &engine, bool allow_empty=false)
Constructs a primitive descriptor for a softmax forward propagation primitive.
Definition dnnl.hpp:5752
memory::desc src_desc() const
Returns a source memory descriptor.
Definition dnnl.hpp:5785
memory::desc dst_desc() const
Returns a destination memory descriptor.
Definition dnnl.hpp:5788
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for a softmax forward propagation primitive from a C API primitive ...
Definition dnnl.hpp:5779
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &engine, bool allow_empty=false)
Constructs a primitive descriptor for a softmax forward propagation primitive.
Definition dnnl.hpp:5768
primitive_desc()=default
Default constructor. Produces an empty object.
softmax_forward()=default
Default constructor. Produces an empty object.
softmax_forward(const primitive_desc &pd)
Constructs a softmax forward propagation primitive.
Definition dnnl.hpp:5797
A container for stream attributes.
Definition dnnl.hpp:1002
stream_attr(engine::kind kind)
Constructs stream attributes for a stream that runs on an engine of a particular kind.
Definition dnnl.hpp:1012
handle()=default
Constructs an empty handle object.
stream_attr()=default
Constructs default (empty) stream attributes.
An execution stream.
Definition dnnl.hpp:1047
stream & wait()
Waits for all primitives executing in the stream to finish.
Definition dnnl.hpp:1107
cl_command_queue get_ocl_command_queue() const
Returns the underlying OpenCL queue object.
Definition dnnl.hpp:1097
stream(const engine &engine, flags flags=flags::default_flags, const stream_attr &attr=stream_attr())
Constructs a stream for the specified engine and with behavior controlled by the specified flags.
Definition dnnl.hpp:1073
handle()=default
Constructs an empty handle object.
flags
Stream flags. Can be combined using the bitwise OR operator.
Definition dnnl.hpp:1051
@ out_of_order
Out-of-order execution.
Definition dnnl.hpp:1058
@ default_order
Default order execution.
Definition dnnl.hpp:1054
@ default_flags
Default stream configuration.
Definition dnnl.hpp:1060
@ in_order
In-order execution.
Definition dnnl.hpp:1056
stream()=default
Constructs an empty stream.
stream(const engine &engine, cl_command_queue queue)
Constructs a stream for the specified engine and the OpenCL queue.
Definition dnnl.hpp:1088
Primitive descriptor for a sum primitive.
Definition dnnl.hpp:3358
memory::desc dst_desc() const
Returns a destination memory descriptor.
Definition dnnl.hpp:3440
primitive_desc()=default
Default constructor. Produces an empty object.
primitive_desc_base()=default
Default constructor. Produces an empty object.
memory::desc src_desc(int idx=0) const
Returns a source memory descriptor.
Definition dnnl.hpp:3437
primitive_desc(const std::vector< float > &scales, const std::vector< memory::desc > &srcs, const engine &engine, const primitive_attr &attr=primitive_attr())
Constructs a primitive descriptor for a sum primitive.
Definition dnnl.hpp:3411
primitive_desc(const memory::desc &dst, const std::vector< float > &scales, const std::vector< memory::desc > &srcs, const engine &engine, const primitive_attr &attr=primitive_attr())
Constructs a primitive descriptor for a sum primitive.
Definition dnnl.hpp:3381
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for sum primitive from a C API primitive descriptor which must have...
Definition dnnl.hpp:3433
sum()=default
Default constructor. Produces an empty object.
sum(const primitive_desc &pd)
Constructs a sum primitive.
Definition dnnl.hpp:3448
Abstract threadpool interface.
Definition dnnl_threadpool_iface.hpp:27
Descriptor for a vanilla RNN backward propagation primitive.
Definition dnnl.hpp:7657
desc(prop_kind prop_kind, algorithm activation, rnn_direction direction, const memory::desc &src_layer_desc, const memory::desc &src_iter_desc, const memory::desc &weights_layer_desc, const memory::desc &weights_iter_desc, const memory::desc &bias_desc, const memory::desc &dst_layer_desc, const memory::desc &dst_iter_desc, const memory::desc &diff_src_layer_desc, const memory::desc &diff_src_iter_desc, const memory::desc &diff_weights_layer_desc, const memory::desc &diff_weights_iter_desc, const memory::desc &diff_bias_desc, const memory::desc &diff_dst_layer_desc, const memory::desc &diff_dst_iter_desc, rnn_flags flags=rnn_flags::undef, float alpha=0.0f, float beta=0.0f)
Constructs a descriptor for a vanilla RNN backward propagation primitive.
Definition dnnl.hpp:7735
Primitive descriptor for an RNN backward propagation primitive.
Definition dnnl.hpp:7771
memory::desc src_iter_desc() const
Returns source iteration memory descriptor.
Definition dnnl.hpp:7831
memory::desc diff_dst_layer_desc() const
Returns diff destination layer memory descriptor.
Definition dnnl.hpp:7885
memory::desc dst_layer_desc() const
Returns destination layer memory descriptor.
Definition dnnl.hpp:7847
memory::desc diff_src_iter_desc() const
Returns diff source iteration memory descriptor.
Definition dnnl.hpp:7865
memory::desc diff_weights_iter_desc() const
Returns diff weights iteration memory descriptor.
Definition dnnl.hpp:7875
primitive_desc(const desc &desc, const engine &engine, const vanilla_rnn_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for a vanilla RNN backward propagation primitive.
Definition dnnl.hpp:7788
primitive_desc()=default
Default constructor. Produces an empty object.
memory::desc diff_bias_desc() const
Returns diff bias memory descriptor.
Definition dnnl.hpp:7880
memory::desc weights_iter_desc() const
Returns weights iteration memory descriptor.
Definition dnnl.hpp:7839
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for a vanilla RNN backward propagation primitive from a C API primi...
Definition dnnl.hpp:7821
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &engine, const vanilla_rnn_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for a vanilla RNN backward propagation primitive.
Definition dnnl.hpp:7808
memory::desc weights_layer_desc() const
Returns weights layer memory descriptor.
Definition dnnl.hpp:7834
memory::desc bias_desc() const
Returns bias memory descriptor.
Definition dnnl.hpp:7844
memory::desc dst_iter_desc() const
Returns destination iteration memory descriptor.
Definition dnnl.hpp:7852
memory::desc diff_dst_iter_desc() const
Returns diff destination iteration memory descriptor.
Definition dnnl.hpp:7890
memory::desc diff_src_layer_desc() const
Returns diff source layer memory descriptor.
Definition dnnl.hpp:7860
memory::desc src_layer_desc() const
Returns source layer memory descriptor.
Definition dnnl.hpp:7826
memory::desc diff_weights_layer_desc() const
Returns diff weights layer memory descriptor.
Definition dnnl.hpp:7870
memory::desc workspace_desc() const
Returns the workspace memory descriptor.
Definition dnnl.hpp:7855
vanilla_rnn_backward(const primitive_desc &pd)
Constructs a vanilla RNN backward propagation primitive.
Definition dnnl.hpp:7901
vanilla_rnn_backward()=default
Default constructor. Produces an empty object.
Descriptor for a vanilla RNN forward propagation primitive.
Definition dnnl.hpp:7483
desc(prop_kind prop_kind, algorithm activation, rnn_direction direction, const memory::desc &src_layer_desc, const memory::desc &src_iter_desc, const memory::desc &weights_layer_desc, const memory::desc &weights_iter_desc, const memory::desc &bias_desc, const memory::desc &dst_layer_desc, const memory::desc &dst_iter_desc, rnn_flags flags=rnn_flags::undef, float alpha=0.0f, float beta=0.0f)
Constructs a descriptor for a vanilla RNN forward propagation primitive.
Definition dnnl.hpp:7539
Primitive descriptor for a vanilla RNN forward propagation primitive.
Definition dnnl.hpp:7564
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &engine, bool allow_empty=false)
Constructs a primitive descriptor for a vanilla RNN forward propagation primitive.
Definition dnnl.hpp:7594
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for a vanilla RNN forward propagation primitive from a C API primit...
Definition dnnl.hpp:7605
memory::desc src_layer_desc() const
Returns source layer memory descriptor.
Definition dnnl.hpp:7611
primitive_desc(const desc &desc, const engine &engine, bool allow_empty=false)
Constructs a primitive descriptor for a vanilla RNN forward propagation primitive.
Definition dnnl.hpp:7578
memory::desc src_iter_desc() const
Returns source iteration memory descriptor.
Definition dnnl.hpp:7616
memory::desc weights_iter_desc() const
Returns weights iteration memory descriptor.
Definition dnnl.hpp:7624
memory::desc weights_layer_desc() const
Returns weights layer memory descriptor.
Definition dnnl.hpp:7619
memory::desc workspace_desc() const
Returns the workspace memory descriptor.
Definition dnnl.hpp:7640
memory::desc dst_iter_desc() const
Returns destination iteration memory descriptor.
Definition dnnl.hpp:7637
primitive_desc()=default
Default constructor. Produces an empty object.
memory::desc dst_layer_desc() const
Returns destination layer memory descriptor.
Definition dnnl.hpp:7632
memory::desc bias_desc() const
Returns bias memory descriptor.
Definition dnnl.hpp:7629
vanilla_rnn_forward()=default
Default constructor. Produces an empty object.
vanilla_rnn_forward(const primitive_desc &pd)
Constructs a vanilla RNN forward propagation primitive.
Definition dnnl.hpp:7651
A descriptor of a Batch Normalization operation.
Definition dnnl_types.h:1360
A descriptor of a binary operation.
Definition dnnl_types.h:1568
A descriptor of a convolution operation.
Definition dnnl_types.h:1134
A descriptor of an element-wise operation.
Definition dnnl_types.h:1209
A descriptor of an inner product operation.
Definition dnnl_types.h:1430
A descriptor of a Layer Normalization operation.
Definition dnnl_types.h:1393
A descriptor of a Local Response Normalization (LRN) operation.
Definition dnnl_types.h:1329
A descriptor of a matrix multiplication operation.
Definition dnnl_types.h:1594
Memory descriptor.
Definition dnnl_types.h:1050
int ndims
Number of dimensions.
Definition dnnl_types.h:1052
A descriptor of a pooling operation.
Definition dnnl_types.h:1291
A descriptor of a resampling operation.
Definition dnnl_types.h:1616
A descriptor for an RNN operation.
Definition dnnl_types.h:1486
A descriptor of a shuffle operation.
Definition dnnl_types.h:1187
A descriptor of a Softmax operation.
Definition dnnl_types.h:1261
Structure containing version information as per Semantic Versioning.
Definition dnnl_types.h:2120