33 #include <cudnn_backend.h> 61 ss <<
"CUDNN_BACKEND_POINTWISE_DESCRIPTOR :" 69 case CUDNN_POINTWISE_ADD:
70 case CUDNN_POINTWISE_MUL:
71 case CUDNN_POINTWISE_MIN:
72 case CUDNN_POINTWISE_MAX:
73 case CUDNN_POINTWISE_RELU_BWD:
74 case CUDNN_POINTWISE_TANH_BWD:
75 case CUDNN_POINTWISE_SIGMOID_BWD:
76 case CUDNN_POINTWISE_ELU_BWD:
77 case CUDNN_POINTWISE_GELU_BWD:
78 case CUDNN_POINTWISE_SOFTPLUS_BWD:
79 case CUDNN_POINTWISE_SWISH_BWD:
81 case CUDNN_POINTWISE_SQRT:
82 case CUDNN_POINTWISE_RELU_FWD:
83 case CUDNN_POINTWISE_TANH_FWD:
84 case CUDNN_POINTWISE_SIGMOID_FWD:
85 case CUDNN_POINTWISE_ELU_FWD:
86 case CUDNN_POINTWISE_GELU_FWD:
87 case CUDNN_POINTWISE_SOFTPLUS_FWD:
88 case CUDNN_POINTWISE_SWISH_FWD:
121 cudnnPointwiseMode_t
mode = CUDNN_POINTWISE_ADD;
143 m_pointWiseDesc.math_precision = data_type_;
149 m_pointWiseDesc.upper_clip = u;
150 m_pointWiseDesc.lower_clip = l;
156 m_pointWiseDesc.mode = mode_;
162 m_pointWiseDesc.nan_propagation = nan_mode_;
169 m_pointWiseDesc.lower_clip = lower_clip_;
175 m_pointWiseDesc.upper_clip = upper_clip_;
181 m_pointWiseDesc.lower_clip_slope = lower_clip_slope_;
187 m_pointWiseDesc.elu_alpha = elu_alpha_;
193 m_pointWiseDesc.softplus_beta = softplus_beta_;
199 m_pointWiseDesc.swish_beta = swish_beta_;
208 auto status = m_pointWiseDesc.initialize_managed_backend_pointer(CUDNN_BACKEND_POINTWISE_DESCRIPTOR);
209 if (
status != CUDNN_STATUS_SUCCESS) {
211 &m_pointWiseDesc,
status,
"CUDNN_BACKEND_POINTWISE_DESCRIPTOR: cudnnCreate Failed");
212 return std::move(m_pointWiseDesc);
216 status = cudnnBackendSetAttribute(m_pointWiseDesc.pointer->get_backend_descriptor(),
217 CUDNN_ATTR_POINTWISE_MODE,
218 CUDNN_TYPE_POINTWISE_MODE,
220 &m_pointWiseDesc.mode);
221 if (
status != CUDNN_STATUS_SUCCESS) {
225 "CUDNN_BACKEND_POINTWISE_DESCRIPTOR: CUDNN_TYPE_POINTWISE_MODE SetAttribute Failed");
226 return std::move(m_pointWiseDesc);
229 status = cudnnBackendSetAttribute(m_pointWiseDesc.pointer->get_backend_descriptor(),
230 CUDNN_ATTR_POINTWISE_MATH_PREC,
231 CUDNN_TYPE_DATA_TYPE,
233 &m_pointWiseDesc.math_precision);
234 if (
status != CUDNN_STATUS_SUCCESS) {
238 "CUDNN_BACKEND_POINTWISE_DESCRIPTOR: SetAttribute CUDNN_ATTR_POINTWISE_MATH_PREC Failed");
239 return std::move(m_pointWiseDesc);
242 if (m_pointWiseDesc.mode == CUDNN_POINTWISE_RELU_FWD || m_pointWiseDesc.mode == CUDNN_POINTWISE_RELU_BWD) {
243 status = cudnnBackendSetAttribute(m_pointWiseDesc.pointer->get_backend_descriptor(),
244 CUDNN_ATTR_POINTWISE_NAN_PROPAGATION,
245 CUDNN_TYPE_NAN_PROPOGATION,
247 &m_pointWiseDesc.nan_propagation);
248 if (
status != CUDNN_STATUS_SUCCESS) {
252 "CUDNN_BACKEND_POINTWISE_DESCRIPTOR: SetAttribute CUDNN_ATTR_POINTWISE_NAN_PROPAGATION Failed");
253 return std::move(m_pointWiseDesc);
256 status = cudnnBackendSetAttribute(m_pointWiseDesc.pointer->get_backend_descriptor(),
257 CUDNN_ATTR_POINTWISE_RELU_LOWER_CLIP,
260 &m_pointWiseDesc.lower_clip);
261 if (
status != CUDNN_STATUS_SUCCESS) {
265 "CUDNN_BACKEND_POINTWISE_DESCRIPTOR: SetAttribute CUDNN_ATTR_POINTWISE_RELU_LOWER_CLIP, Failed");
266 return std::move(m_pointWiseDesc);
269 if (m_pointWiseDesc.math_precision == CUDNN_DATA_FLOAT) {
270 double clamped_upper_clip =
271 std::min<double>(m_pointWiseDesc.upper_clip, std::numeric_limits<float>::max());
272 status = cudnnBackendSetAttribute(m_pointWiseDesc.pointer->get_backend_descriptor(),
273 CUDNN_ATTR_POINTWISE_RELU_UPPER_CLIP,
276 &clamped_upper_clip);
279 status = cudnnBackendSetAttribute(m_pointWiseDesc.pointer->get_backend_descriptor(),
280 CUDNN_ATTR_POINTWISE_RELU_UPPER_CLIP,
283 &m_pointWiseDesc.upper_clip);
285 if (
status != CUDNN_STATUS_SUCCESS) {
289 "CUDNN_BACKEND_POINTWISE_DESCRIPTOR: SetAttribute CUDNN_ATTR_POINTWISE_RELU_UPPER_CLIP, Failed");
290 return std::move(m_pointWiseDesc);
293 status = cudnnBackendSetAttribute(m_pointWiseDesc.pointer->get_backend_descriptor(),
294 CUDNN_ATTR_POINTWISE_RELU_LOWER_CLIP_SLOPE,
297 &m_pointWiseDesc.lower_clip_slope);
298 if (
status != CUDNN_STATUS_SUCCESS) {
301 "CUDNN_BACKEND_POINTWISE_DESCRIPTOR: SetAttribute " 302 "CUDNN_ATTR_POINTWISE_RELU_LOWER_CLIP_SLOPE, Failed");
303 return std::move(m_pointWiseDesc);
305 }
else if (m_pointWiseDesc.mode == CUDNN_POINTWISE_ELU_FWD || m_pointWiseDesc.mode == CUDNN_POINTWISE_ELU_BWD) {
306 status = cudnnBackendSetAttribute(m_pointWiseDesc.pointer->get_backend_descriptor(),
307 CUDNN_ATTR_POINTWISE_ELU_ALPHA,
310 &m_pointWiseDesc.elu_alpha);
311 if (
status != CUDNN_STATUS_SUCCESS) {
315 "CUDNN_BACKEND_POINTWISE_DESCRIPTOR: SetAttribute CUDNN_ATTR_POINTWISE_ELU_ALPHA, Failed");
316 return std::move(m_pointWiseDesc);
318 }
else if (m_pointWiseDesc.mode == CUDNN_POINTWISE_SOFTPLUS_FWD ||
319 m_pointWiseDesc.mode == CUDNN_POINTWISE_SOFTPLUS_BWD) {
320 status = cudnnBackendSetAttribute(m_pointWiseDesc.pointer->get_backend_descriptor(),
321 CUDNN_ATTR_POINTWISE_SOFTPLUS_BETA,
324 &m_pointWiseDesc.softplus_beta);
325 if (
status != CUDNN_STATUS_SUCCESS) {
329 "CUDNN_BACKEND_POINTWISE_DESCRIPTOR: SetAttribute CUDNN_ATTR_POINTWISE_SOFTPLUS_BETA, Failed");
330 return std::move(m_pointWiseDesc);
332 }
else if (m_pointWiseDesc.mode == CUDNN_POINTWISE_SWISH_FWD ||
333 m_pointWiseDesc.mode == CUDNN_POINTWISE_SWISH_BWD) {
334 status = cudnnBackendSetAttribute(m_pointWiseDesc.pointer->get_backend_descriptor(),
335 CUDNN_ATTR_POINTWISE_SWISH_BETA,
338 &m_pointWiseDesc.swish_beta);
339 if (
status != CUDNN_STATUS_SUCCESS) {
343 "CUDNN_BACKEND_POINTWISE_DESCRIPTOR: SetAttribute CUDNN_ATTR_POINTWISE_SWISH_BETA, Failed");
344 return std::move(m_pointWiseDesc);
349 status = cudnnBackendFinalize(m_pointWiseDesc.pointer->get_backend_descriptor());
350 if (
status != CUDNN_STATUS_SUCCESS) {
352 &m_pointWiseDesc,
status,
"CUDNN_BACKEND_POINTWISE_DESCRIPTOR: cudnnFinalize Failed");
353 return std::move(m_pointWiseDesc);
356 return std::move(m_pointWiseDesc);
PointWiseDesc_v8()=default
static void set_error_and_throw_exception(BackendDescriptor const *desc, cudnnStatus_t status, const char *message)
auto setClipping(double l, double u) -> PointWiseDescBuilder_v8 &
Set upper and lower limits for the RELU activation.
PointWiseDesc_v8 & operator=(PointWiseDesc_v8 const &)=delete
auto setMode(cudnnNanPropagation_t nan_mode_) -> PointWiseDescBuilder_v8 &
Set NaN propagation mode.
auto setSwishBeta(double swish_beta_) -> PointWiseDescBuilder_v8 &
PointWiseDesc_v8(PointWiseDesc_v8 &&from)
cudnnPointwiseMode_t getPointWiseMode() const
cudnnNanPropagation_t nan_propagation
~PointWiseDesc_v8()=default
ManagedOpaqueDescriptor get_desc() const
Returns a copy of the underlying managed descriptor.
auto setReluLowerClip(double lower_clip_) -> PointWiseDescBuilder_v8 &
auto setSoftplusBeta(double softplus_beta_) -> PointWiseDescBuilder_v8 &
friend class PointWiseDescBuilder_v8
std::string describe() const override
Returns a string describing the backend descriptor.
auto setReluLowerClipSlope(double lower_clip_slope_) -> PointWiseDescBuilder_v8 &
int64_t getPortCount() const
cudnnStatus_t get_status() const
Current status of the descriptor.
PointWiseDesc_v8 m_pointWiseDesc
const char * get_error() const
Diagnostic error message, if any.
cudnnDataType_t math_precision
auto setMathPrecision(cudnnDataType_t data_type_) -> PointWiseDescBuilder_v8 &
Set the math precision data type for the pointwise operation.
auto setMode(cudnnPointwiseMode_t mode_) -> PointWiseDescBuilder_v8 &
Set pointwise mode for the activation.
auto setEluAlpha(double elu_alpha_) -> PointWiseDescBuilder_v8 &
cudnnPointwiseMode_t mode
auto setReluUpperClip(double upper_clip_) -> PointWiseDescBuilder_v8 &
cudnnStatus_t status
Shared pointer of the OpaqueBackendPointer.
PointWiseDesc_v8 && build()