33 #include <cudnn_backend.h> 73 ss <<
"CUDNN_BACKEND_OPERATION :" 75 ss << std::hex <<
" X " <<
xdesc;
76 ss << std::hex <<
" Y " <<
ydesc;
77 ss << std::hex <<
" W " <<
wdesc;
78 ss << std::hex <<
" B " <<
bdesc;
79 ss << std::hex <<
" DW " <<
dwdesc;
80 ss << std::hex <<
" DY " <<
dydesc;
81 ss << std::hex <<
" DX " <<
dxdesc;
82 ss << std::hex <<
" C " <<
cdesc;
83 ss << std::hex <<
" A Mtrix " <<
amatdesc;
84 ss << std::hex <<
" B Mtrix " <<
bmatdesc;
85 ss << std::hex <<
" C Mtrix " <<
cmatdesc;
86 ss << std::hex <<
" P " <<
pwdesc;
141 cudnnBackendDescriptorType_t
op_mode = CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR;
176 bool is_convolution_op =
false;
177 bool is_pointwise_op =
false;
178 bool is_matmul_op =
false;
179 bool is_reduction_op =
false;
188 m_operation.
xdesc = raw_tensor;
194 m_operation.
xdesc = tensor.get_desc();
199 if (is_pointwise_op ==
false) {
202 CUDNN_STATUS_BAD_PARAM,
203 "CUDNN_BACKEND_OPERATION_*_DESCRIPTOR: Non Pointwise operation does not need bTensor");
205 m_operation.
bdesc = tensor.get_desc();
210 m_operation.
ydesc = tensor.get_desc();
215 if (is_convolution_op ==
false) {
218 CUDNN_STATUS_BAD_PARAM,
219 "CUDNN_BACKEND_OPERATION_*_DESCRIPTOR: Non Convolution operation does not need wTensor");
221 m_operation.
wdesc = tensor.get_desc();
227 m_operation.
dydesc = raw_tensor;
232 m_operation.
dydesc = tensor.get_desc();
237 m_operation.
dxdesc = tensor.get_desc();
242 m_operation.
dwdesc = tensor.get_desc();
248 if (is_convolution_op ==
false) {
251 CUDNN_STATUS_BAD_PARAM,
252 "CUDNN_BACKEND_OPERATION_*_DESCRIPTOR: Non Convolution operation does not need Convolution DESCRIPTOR");
254 m_operation.
cdesc = conv.get_desc();
259 if (is_matmul_op ==
false) {
262 CUDNN_STATUS_BAD_PARAM,
263 "CUDNN_BACKEND_OPERATION_*_DESCRIPTOR: Non Matmul operation does not need a Matrix Tensor");
265 m_operation.
amatdesc = tensor.get_desc();
270 if (is_matmul_op ==
false) {
273 CUDNN_STATUS_BAD_PARAM,
274 "CUDNN_BACKEND_OPERATION_*_DESCRIPTOR: Non Matmul operation does not need b Matrix Tensor");
276 m_operation.
bmatdesc = tensor.get_desc();
281 if (is_matmul_op ==
false) {
284 CUDNN_STATUS_BAD_PARAM,
285 "CUDNN_BACKEND_OPERATION_*_DESCRIPTOR: Non Matmul operation does not need c Matrix Tensor");
287 m_operation.
cmatdesc = tensor.get_desc();
292 if (is_matmul_op ==
false) {
295 CUDNN_STATUS_BAD_PARAM,
296 "CUDNN_BACKEND_OPERATION_*_DESCRIPTOR: Non Matmul operation does not need MATMUL DESCRIPTOR");
298 m_operation.
matmuldesc = matmulDesc.get_desc();
303 if (is_reduction_op ==
false) {
306 CUDNN_STATUS_BAD_PARAM,
307 "CUDNN_BACKEND_OPERATION_*_DESCRIPTOR: Non Reduction operation does not need REDUCTION DESCRIPTOR");
314 if (is_pointwise_op ==
false) {
317 CUDNN_STATUS_BAD_PARAM,
318 "CUDNN_BACKEND_OPERATION_*_DESCRIPTOR: Non Pointwise operation does not need POINTWISE DESCRIPTOR");
320 m_operation.
pwdesc = pointWiseDesc.get_desc();
352 m_operation.
alpha_d =
static_cast<double>(alpha);
359 m_operation.
alpha_s =
static_cast<float>(alpha);
366 m_operation.
alpha2_d =
static_cast<double>(alpha);
373 m_operation.
alpha2_s =
static_cast<float>(alpha);
380 m_operation.
beta_d =
static_cast<double>(beta);
381 m_operation.
beta_s = beta;
387 m_operation.
beta_s =
static_cast<float>(beta);
388 m_operation.
beta_d = beta;
394 is_convolution_op = ((m_operation.
op_mode == CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR) ||
395 (m_operation.
op_mode == CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR) ||
396 (m_operation.
op_mode == CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR));
398 is_pointwise_op = (m_operation.
op_mode == CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR);
399 is_matmul_op = (m_operation.
op_mode == CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR);
400 is_reduction_op = (m_operation.
op_mode == CUDNN_BACKEND_OPERATION_REDUCTION_DESCRIPTOR);
408 if (m_operation.
status != CUDNN_STATUS_SUCCESS) {
410 &m_operation, m_operation.
status,
"CUDNN_BACKEND_OPERATION: Operation not initialized properly");
411 return std::move(m_operation);
414 if (is_convolution_op) {
415 if (m_operation.
cdesc ==
nullptr) {
418 CUDNN_STATUS_BAD_PARAM,
419 "CUDNN_BACKEND_OPERATION: Check and Set the CUDNN_ATTR_OPERATION_CONVOLUTION_*_CONV_DESC");
420 return std::move(m_operation);
422 if (m_operation.
op_mode == CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR) {
423 if (m_operation.
xdesc ==
nullptr) {
426 CUDNN_STATUS_BAD_PARAM,
427 "CUDNN_BACKEND_OPERATION: Check and Set the CUDNN_ATTR_OPERATION_CONVOLUTION_*_X");
428 return std::move(m_operation);
430 if (m_operation.
wdesc ==
nullptr) {
433 CUDNN_STATUS_BAD_PARAM,
434 "CUDNN_BACKEND_OPERATION: Check and Set the CUDNN_ATTR_OPERATION_CONVOLUTION_*_W");
435 return std::move(m_operation);
437 if (m_operation.
ydesc ==
nullptr) {
440 CUDNN_STATUS_BAD_PARAM,
441 "CUDNN_BACKEND_OPERATION: Check and Set the CUDNN_ATTR_OPERATION_CONVOLUTION_*_Y");
442 return std::move(m_operation);
445 }
else if (m_operation.
op_mode == CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR) {
446 if (m_operation.
ydesc !=
nullptr && m_operation.
dydesc !=
nullptr) {
448 CUDNN_STATUS_BAD_PARAM,
449 "CUDNN_BACKEND_OPERATION: Ambiguous specification. Choose and Set " 450 "only one of setyDesc() or setdyDesc()");
451 return std::move(m_operation);
453 if (m_operation.
ydesc ==
nullptr && m_operation.
dydesc ==
nullptr) {
456 CUDNN_STATUS_BAD_PARAM,
457 "CUDNN_BACKEND_OPERATION: Choose and Set one of setyDesc() or setdyDesc()");
458 return std::move(m_operation);
460 if (m_operation.
xdesc ==
nullptr) {
463 CUDNN_STATUS_BAD_PARAM,
464 "CUDNN_BACKEND_OPERATION: Check and Set the CUDNN_ATTR_OPERATION_CONVOLUTION_*_X");
465 return std::move(m_operation);
467 if (m_operation.
wdesc !=
nullptr && m_operation.
dwdesc !=
nullptr) {
469 CUDNN_STATUS_BAD_PARAM,
470 "CUDNN_BACKEND_OPERATION: Ambiguous specification. Choose and Set " 471 "only one of setwDesc() or setdwDesc()");
472 return std::move(m_operation);
474 if (m_operation.
wdesc ==
nullptr && m_operation.
dwdesc ==
nullptr) {
477 CUDNN_STATUS_BAD_PARAM,
478 "CUDNN_BACKEND_OPERATION: Choose and Set one of setwDesc() or setdwDesc()");
479 return std::move(m_operation);
481 }
else if (m_operation.
op_mode == CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR) {
482 if (m_operation.
ydesc !=
nullptr && m_operation.
dydesc !=
nullptr) {
484 CUDNN_STATUS_BAD_PARAM,
485 "CUDNN_BACKEND_OPERATION: Ambiguous specification. Choose and Set " 486 "only one of setyDesc() or setdyDesc()");
487 return std::move(m_operation);
489 if (m_operation.
ydesc ==
nullptr && m_operation.
dydesc ==
nullptr) {
492 CUDNN_STATUS_BAD_PARAM,
493 "CUDNN_BACKEND_OPERATION: Choose and Set one of setyDesc() or setdyDesc()");
494 return std::move(m_operation);
496 if (m_operation.
wdesc ==
nullptr) {
499 CUDNN_STATUS_BAD_PARAM,
500 "CUDNN_BACKEND_OPERATION: Check and Set the CUDNN_ATTR_OPERATION_CONVOLUTION_*_W");
501 return std::move(m_operation);
503 if (m_operation.
xdesc !=
nullptr && m_operation.
dxdesc !=
nullptr) {
505 CUDNN_STATUS_BAD_PARAM,
506 "CUDNN_BACKEND_OPERATION: Ambiguous specification. Choose and Set " 507 "only one of setxDesc() or setdxDesc()");
508 return std::move(m_operation);
510 if (m_operation.
xdesc ==
nullptr && m_operation.
dxdesc ==
nullptr) {
513 CUDNN_STATUS_BAD_PARAM,
514 "CUDNN_BACKEND_OPERATION: Choose and Set one of setxDesc() or setdxDesc()");
515 return std::move(m_operation);
519 CUDNN_STATUS_BAD_PARAM,
520 "CUDNN_BACKEND_OPERATION: Unsupported convolution operation. Check and " 521 "set CUDNN_BACKEND_OPERATION_CONVOLUTION_*_DESCRIPTOR");
522 return std::move(m_operation);
524 }
else if (is_pointwise_op) {
525 if (m_operation.
xdesc ==
nullptr) {
528 CUDNN_STATUS_BAD_PARAM,
529 "CUDNN_BACKEND_OPERATION: Check and Set the CUDNN_ATTR_OPERATION_POINTWISE_XDESC");
530 return std::move(m_operation);
537 CUDNN_STATUS_BAD_PARAM,
538 "CUDNN_BACKEND_OPERATION: Check and Set the CUDNN_ATTR_OPERATION_POINTWISE_BDESC");
539 return std::move(m_operation);
541 if (m_operation.
ydesc ==
nullptr) {
544 CUDNN_STATUS_BAD_PARAM,
545 "CUDNN_BACKEND_OPERATION: Check and Set the CUDNN_ATTR_OPERATION_POINTWISE_YDESC");
546 return std::move(m_operation);
549 if (m_operation.
ydesc ==
nullptr) {
552 CUDNN_STATUS_BAD_PARAM,
553 "CUDNN_BACKEND_OPERATION: Check and Set the CUDNN_ATTR_OPERATION_POINTWISE_YDESC");
554 return std::move(m_operation);
557 if (m_operation.
dydesc ==
nullptr) {
560 CUDNN_STATUS_BAD_PARAM,
561 "CUDNN_BACKEND_OPERATION: Check and Set the CUDNN_ATTR_OPERATION_POINTWISE_DYDESC");
562 return std::move(m_operation);
564 if (m_operation.
dxdesc ==
nullptr) {
567 CUDNN_STATUS_BAD_PARAM,
568 "CUDNN_BACKEND_OPERATION: Check and Set the CUDNN_ATTR_OPERATION_POINTWISE_DXDESC");
569 return std::move(m_operation);
574 CUDNN_STATUS_BAD_PARAM,
575 "CUDNN_BACKEND_OPERATION: Unsupported cudnn pointwise mode. Check and set CUDNN_POINTWISE_*");
576 return std::move(m_operation);
579 }
else if (is_matmul_op) {
583 CUDNN_STATUS_BAD_PARAM,
584 "CUDNN_BACKEND_OPERATION: Check and Set the CUDNN_ATTR_OPERATION_MATMUL_DESC");
585 return std::move(m_operation);
587 if (m_operation.
amatdesc ==
nullptr) {
590 CUDNN_STATUS_BAD_PARAM,
591 "CUDNN_BACKEND_OPERATION: Check and Set the CUDNN_ATTR_OPERATION_MATMUL_ADESC");
592 return std::move(m_operation);
594 if (m_operation.
bmatdesc ==
nullptr) {
597 CUDNN_STATUS_BAD_PARAM,
598 "CUDNN_BACKEND_OPERATION: Check and Set the CUDNN_ATTR_OPERATION_MATMUL_BDESC");
599 return std::move(m_operation);
601 if (m_operation.
cmatdesc ==
nullptr) {
604 CUDNN_STATUS_BAD_PARAM,
605 "CUDNN_BACKEND_OPERATION: Check and Set the CUDNN_ATTR_OPERATION_MATMUL_CDESC");
606 return std::move(m_operation);
608 }
else if (is_reduction_op) {
612 CUDNN_STATUS_BAD_PARAM,
613 "CUDNN_BACKEND_OPERATION: Check and Set the CUDNN_ATTR_OPERATION_REDUCTION_DESC");
614 return std::move(m_operation);
616 if (m_operation.
xdesc ==
nullptr) {
619 CUDNN_STATUS_BAD_PARAM,
620 "CUDNN_BACKEND_OPERATION: Check and Set the CUDNN_ATTR_OPERATION_REDUCTION_XDESC");
621 return std::move(m_operation);
623 if (m_operation.
ydesc ==
nullptr) {
626 CUDNN_STATUS_BAD_PARAM,
627 "CUDNN_BACKEND_OPERATION: Check and Set the CUDNN_ATTR_OPERATION_REDUCTION_YDESC");
628 return std::move(m_operation);
632 CUDNN_STATUS_BAD_PARAM,
633 "CUDNN_BACKEND_OPERATION_DESCRIPTOR: Unsupported cudnn backend descriptor " 634 "type. Check and set CUDNN_BACKEND_OPERATION_*_DESCRIPTOR");
635 return std::move(m_operation);
640 if (
status != CUDNN_STATUS_SUCCESS) {
642 return std::move(m_operation);
645 if (m_operation.
op_mode == CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR) {
648 status = cudnnBackendSetAttribute(m_operation.
pointer->get_backend_descriptor(),
649 CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_X,
650 CUDNN_TYPE_BACKEND_DESCRIPTOR,
652 &(m_operation.
xdesc->get_backend_descriptor()));
653 if (
status != CUDNN_STATUS_SUCCESS) {
657 "CUDNN_BACKEND_OPERATION: SetAttribute CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_X Failed");
658 return std::move(m_operation);
660 status = cudnnBackendSetAttribute(m_operation.
pointer->get_backend_descriptor(),
661 CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_W,
662 CUDNN_TYPE_BACKEND_DESCRIPTOR,
664 &(m_operation.
wdesc->get_backend_descriptor()));
665 if (
status != CUDNN_STATUS_SUCCESS) {
669 "CUDNN_BACKEND_OPERATION: SetAttribute CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_W Failed");
670 return std::move(m_operation);
672 status = cudnnBackendSetAttribute(m_operation.
pointer->get_backend_descriptor(),
673 CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_Y,
674 CUDNN_TYPE_BACKEND_DESCRIPTOR,
676 &(m_operation.
ydesc->get_backend_descriptor()));
677 if (
status != CUDNN_STATUS_SUCCESS) {
681 "CUDNN_BACKEND_OPERATION: SetAttribute CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_Y Failed");
682 return std::move(m_operation);
684 status = cudnnBackendSetAttribute(m_operation.
pointer->get_backend_descriptor(),
685 CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_CONV_DESC,
686 CUDNN_TYPE_BACKEND_DESCRIPTOR,
688 &(m_operation.
cdesc->get_backend_descriptor()));
689 if (
status != CUDNN_STATUS_SUCCESS) {
693 "CUDNN_BACKEND_OPERATION: SetAttribute CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_CONV_DESC Failed");
694 return std::move(m_operation);
696 void *alpha = (m_operation.
alphabetaType == CUDNN_TYPE_FLOAT ?
static_cast<void *
>(&m_operation.
alpha_s)
697 : static_cast<void *>(&m_operation.
alpha_d));
698 void *beta = (m_operation.
alphabetaType == CUDNN_TYPE_FLOAT ?
static_cast<void *
>(&m_operation.
beta_s)
699 : static_cast<void *>(&m_operation.
beta_d));
700 status = cudnnBackendSetAttribute(m_operation.
pointer->get_backend_descriptor(),
701 CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_ALPHA,
705 if (
status != CUDNN_STATUS_SUCCESS) {
709 "CUDNN_BACKEND_OPERATION: SetAttribute CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_ALPHA Failed");
710 return std::move(m_operation);
712 status = cudnnBackendSetAttribute(m_operation.
pointer->get_backend_descriptor(),
713 CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_BETA,
717 if (
status != CUDNN_STATUS_SUCCESS) {
721 "CUDNN_BACKEND_OPERATION: SetAttribute CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_BETA Failed");
722 return std::move(m_operation);
724 }
else if (m_operation.
op_mode == CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR) {
727 status = cudnnBackendSetAttribute(m_operation.
pointer->get_backend_descriptor(),
728 CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_X,
729 CUDNN_TYPE_BACKEND_DESCRIPTOR,
731 &(m_operation.
xdesc->get_backend_descriptor()));
732 if (
status != CUDNN_STATUS_SUCCESS) {
736 "CUDNN_BACKEND_OPERATION: SetAttribute CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_X Failed");
737 return std::move(m_operation);
740 auto dwdesc_ = m_operation.
dwdesc !=
nullptr ? m_operation.
dwdesc : m_operation.
wdesc;
741 status = cudnnBackendSetAttribute(m_operation.
pointer->get_backend_descriptor(),
742 CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_DW,
743 CUDNN_TYPE_BACKEND_DESCRIPTOR,
745 &(dwdesc_->get_backend_descriptor()));
746 if (
status != CUDNN_STATUS_SUCCESS) {
750 "CUDNN_BACKEND_OPERATION: SetAttribute CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_DW Failed");
751 return std::move(m_operation);
754 auto dydesc_ = m_operation.
dydesc !=
nullptr ? m_operation.
dydesc : m_operation.
ydesc;
755 status = cudnnBackendSetAttribute(m_operation.
pointer->get_backend_descriptor(),
756 CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_DY,
757 CUDNN_TYPE_BACKEND_DESCRIPTOR,
759 &(dydesc_->get_backend_descriptor()));
760 if (
status != CUDNN_STATUS_SUCCESS) {
764 "CUDNN_BACKEND_OPERATION: SetAttribute CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_DY Failed");
765 return std::move(m_operation);
768 status = cudnnBackendSetAttribute(m_operation.
pointer->get_backend_descriptor(),
769 CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_CONV_DESC,
770 CUDNN_TYPE_BACKEND_DESCRIPTOR,
772 &(m_operation.
cdesc->get_backend_descriptor()));
773 if (
status != CUDNN_STATUS_SUCCESS) {
776 "CUDNN_BACKEND_OPERATION: SetAttribute " 777 "CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_CONV_DESC Failed");
778 return std::move(m_operation);
780 void *alpha = (m_operation.
alphabetaType == CUDNN_TYPE_FLOAT ?
static_cast<void *
>(&m_operation.
alpha_s)
781 : static_cast<void *>(&m_operation.
alpha_d));
782 void *beta = (m_operation.
alphabetaType == CUDNN_TYPE_FLOAT ?
static_cast<void *
>(&m_operation.
beta_s)
783 : static_cast<void *>(&m_operation.
beta_d));
784 status = cudnnBackendSetAttribute(m_operation.
pointer->get_backend_descriptor(),
785 CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_ALPHA,
789 if (
status != CUDNN_STATUS_SUCCESS) {
793 "CUDNN_BACKEND_OPERATION: SetAttribute CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_ALPHA Failed");
794 return std::move(m_operation);
796 status = cudnnBackendSetAttribute(m_operation.
pointer->get_backend_descriptor(),
797 CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_BETA,
801 if (
status != CUDNN_STATUS_SUCCESS) {
805 "CUDNN_BACKEND_OPERATION: SetAttribute CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_BETA Failed");
806 return std::move(m_operation);
808 }
else if (m_operation.
op_mode == CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR) {
811 auto dxdesc_ = m_operation.
dxdesc !=
nullptr ? m_operation.
dxdesc : m_operation.
xdesc;
812 status = cudnnBackendSetAttribute(m_operation.
pointer->get_backend_descriptor(),
813 CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_DX,
814 CUDNN_TYPE_BACKEND_DESCRIPTOR,
816 &(dxdesc_->get_backend_descriptor()));
817 if (
status != CUDNN_STATUS_SUCCESS) {
821 "CUDNN_BACKEND_OPERATION: SetAttribute CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_DX Failed");
822 return std::move(m_operation);
825 status = cudnnBackendSetAttribute(m_operation.
pointer->get_backend_descriptor(),
826 CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_W,
827 CUDNN_TYPE_BACKEND_DESCRIPTOR,
829 &(m_operation.
wdesc->get_backend_descriptor()));
830 if (
status != CUDNN_STATUS_SUCCESS) {
834 "CUDNN_BACKEND_OPERATION: SetAttribute CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_W Failed");
835 return std::move(m_operation);
838 auto dydesc_ = m_operation.
dydesc !=
nullptr ? m_operation.
dydesc : m_operation.
ydesc;
839 status = cudnnBackendSetAttribute(m_operation.
pointer->get_backend_descriptor(),
840 CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_DY,
841 CUDNN_TYPE_BACKEND_DESCRIPTOR,
843 &(dydesc_->get_backend_descriptor()));
844 if (
status != CUDNN_STATUS_SUCCESS) {
848 "CUDNN_BACKEND_OPERATION: SetAttribute CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_DY Failed");
849 return std::move(m_operation);
852 status = cudnnBackendSetAttribute(m_operation.
pointer->get_backend_descriptor(),
853 CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_CONV_DESC,
854 CUDNN_TYPE_BACKEND_DESCRIPTOR,
856 &(m_operation.
cdesc->get_backend_descriptor()));
857 if (
status != CUDNN_STATUS_SUCCESS) {
861 "CUDNN_BACKEND_OPERATION: SetAttribute CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_CONV_DESC Failed");
862 return std::move(m_operation);
865 void *alpha = (m_operation.
alphabetaType == CUDNN_TYPE_FLOAT ?
static_cast<void *
>(&m_operation.
alpha_s)
866 : static_cast<void *>(&m_operation.
alpha_d));
867 void *beta = (m_operation.
alphabetaType == CUDNN_TYPE_FLOAT ?
static_cast<void *
>(&m_operation.
beta_s)
868 : static_cast<void *>(&m_operation.
beta_d));
869 status = cudnnBackendSetAttribute(m_operation.
pointer->get_backend_descriptor(),
870 CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_ALPHA,
874 if (
status != CUDNN_STATUS_SUCCESS) {
878 "CUDNN_BACKEND_OPERATION: SetAttribute CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_ALPHA Failed");
879 return std::move(m_operation);
881 status = cudnnBackendSetAttribute(m_operation.
pointer->get_backend_descriptor(),
882 CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_BETA,
886 if (
status != CUDNN_STATUS_SUCCESS) {
890 "CUDNN_BACKEND_OPERATION: SetAttribute CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_BETA Failed");
891 return std::move(m_operation);
893 }
else if (m_operation.
op_mode == CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) {
895 case CUDNN_POINTWISE_ADD:
898 case CUDNN_POINTWISE_MUL:
901 case CUDNN_POINTWISE_MIN:
904 case CUDNN_POINTWISE_MAX:
907 case CUDNN_POINTWISE_SQRT:
910 case CUDNN_POINTWISE_RELU_FWD:
913 case CUDNN_POINTWISE_TANH_FWD:
916 case CUDNN_POINTWISE_SIGMOID_FWD:
919 case CUDNN_POINTWISE_ELU_FWD:
922 case CUDNN_POINTWISE_GELU_FWD:
925 case CUDNN_POINTWISE_SOFTPLUS_FWD:
928 case CUDNN_POINTWISE_SWISH_FWD:
931 case CUDNN_POINTWISE_RELU_BWD:
934 case CUDNN_POINTWISE_TANH_BWD:
937 case CUDNN_POINTWISE_SIGMOID_BWD:
940 case CUDNN_POINTWISE_ELU_BWD:
943 case CUDNN_POINTWISE_GELU_BWD:
946 case CUDNN_POINTWISE_SOFTPLUS_BWD:
949 case CUDNN_POINTWISE_SWISH_BWD:
957 status = cudnnBackendSetAttribute(m_operation.
pointer->get_backend_descriptor(),
958 CUDNN_ATTR_OPERATION_POINTWISE_PW_DESCRIPTOR,
959 CUDNN_TYPE_BACKEND_DESCRIPTOR,
961 &(m_operation.
pwdesc->get_backend_descriptor()));
962 if (
status != CUDNN_STATUS_SUCCESS) {
966 "CUDNN_BACKEND_OPERATION: SetAttribute CUDNN_ATTR_OPERATION_POINTWISE_PW_DESCRIPTOR Failed");
967 return std::move(m_operation);
970 status = cudnnBackendSetAttribute(m_operation.
pointer->get_backend_descriptor(),
971 CUDNN_ATTR_OPERATION_POINTWISE_XDESC,
972 CUDNN_TYPE_BACKEND_DESCRIPTOR,
974 &(m_operation.
xdesc->get_backend_descriptor()));
975 if (
status != CUDNN_STATUS_SUCCESS) {
979 "CUDNN_BACKEND_OPERATION: SetAttribute CUDNN_ATTR_OPERATION_POINTWISE_XDESC Failed");
980 return std::move(m_operation);
984 status = cudnnBackendSetAttribute(m_operation.
pointer->get_backend_descriptor(),
985 CUDNN_ATTR_OPERATION_POINTWISE_YDESC,
986 CUDNN_TYPE_BACKEND_DESCRIPTOR,
988 &(m_operation.
ydesc->get_backend_descriptor()));
989 if (
status != CUDNN_STATUS_SUCCESS) {
993 "CUDNN_BACKEND_OPERATION: SetAttribute CUDNN_ATTR_OPERATION_POINTWISE_YDESC Failed");
994 return std::move(m_operation);
997 status = cudnnBackendSetAttribute(m_operation.
pointer->get_backend_descriptor(),
998 CUDNN_ATTR_OPERATION_POINTWISE_DYDESC,
999 CUDNN_TYPE_BACKEND_DESCRIPTOR,
1001 &(m_operation.
dydesc->get_backend_descriptor()));
1002 if (
status != CUDNN_STATUS_SUCCESS) {
1006 "CUDNN_BACKEND_OPERATION: SetAttribute CUDNN_ATTR_OPERATION_POINTWISE_DYDESC Failed");
1007 return std::move(m_operation);
1010 status = cudnnBackendSetAttribute(m_operation.
pointer->get_backend_descriptor(),
1011 CUDNN_ATTR_OPERATION_POINTWISE_DXDESC,
1012 CUDNN_TYPE_BACKEND_DESCRIPTOR,
1014 &(m_operation.
dxdesc->get_backend_descriptor()));
1015 if (
status != CUDNN_STATUS_SUCCESS) {
1019 "CUDNN_BACKEND_OPERATION: SetAttribute CUDNN_ATTR_OPERATION_POINTWISE_DXDESC Failed");
1020 return std::move(m_operation);
1024 void *alpha = (m_operation.
alphabetaType == CUDNN_TYPE_FLOAT ?
static_cast<void *
>(&m_operation.
alpha_s)
1025 : static_cast<void *>(&m_operation.
alpha_d));
1026 void *alpha2 = (m_operation.
alphabetaType == CUDNN_TYPE_FLOAT ?
static_cast<void *
>(&m_operation.
alpha2_s)
1027 : static_cast<void *>(&m_operation.
alpha2_d));
1028 status = cudnnBackendSetAttribute(m_operation.
pointer->get_backend_descriptor(),
1029 CUDNN_ATTR_OPERATION_POINTWISE_ALPHA1,
1033 if (
status != CUDNN_STATUS_SUCCESS) {
1037 "CUDNN_BACKEND_OPERATION: SetAttribute CUDNN_ATTR_OPERATION_POINTWISE_ALPHA1 Failed");
1038 return std::move(m_operation);
1040 status = cudnnBackendSetAttribute(m_operation.
pointer->get_backend_descriptor(),
1041 CUDNN_ATTR_OPERATION_POINTWISE_ALPHA2,
1045 if (
status != CUDNN_STATUS_SUCCESS) {
1049 "CUDNN_BACKEND_OPERATION: SetAttribute CUDNN_ATTR_OPERATION_POINTWISE_ALPHA2 Failed");
1050 return std::move(m_operation);
1054 status = cudnnBackendSetAttribute(m_operation.
pointer->get_backend_descriptor(),
1055 CUDNN_ATTR_OPERATION_POINTWISE_BDESC,
1056 CUDNN_TYPE_BACKEND_DESCRIPTOR,
1058 &(m_operation.
bdesc->get_backend_descriptor()));
1059 if (
status != CUDNN_STATUS_SUCCESS) {
1063 "CUDNN_BACKEND_OPERATION: SetAttribute CUDNN_ATTR_OPERATION_POINTWISE_BDESC Failed");
1064 return std::move(m_operation);
1067 }
else if (m_operation.
op_mode == CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR) {
1069 status = cudnnBackendSetAttribute(m_operation.
pointer->get_backend_descriptor(),
1070 CUDNN_ATTR_OPERATION_MATMUL_ADESC,
1071 CUDNN_TYPE_BACKEND_DESCRIPTOR,
1073 &(m_operation.
amatdesc->get_backend_descriptor()));
1074 if (
status != CUDNN_STATUS_SUCCESS) {
1078 "CUDNN_BACKEND_OPERATION: SetAttribute CUDNN_ATTR_OPERATION_MATMUL_ADESC Failed");
1079 return std::move(m_operation);
1081 status = cudnnBackendSetAttribute(m_operation.
pointer->get_backend_descriptor(),
1082 CUDNN_ATTR_OPERATION_MATMUL_BDESC,
1083 CUDNN_TYPE_BACKEND_DESCRIPTOR,
1085 &(m_operation.
bmatdesc->get_backend_descriptor()));
1086 if (
status != CUDNN_STATUS_SUCCESS) {
1090 "CUDNN_BACKEND_OPERATION: SetAttribute CUDNN_ATTR_OPERATION_MATMUL_BDESC Failed");
1091 return std::move(m_operation);
1093 status = cudnnBackendSetAttribute(m_operation.
pointer->get_backend_descriptor(),
1094 CUDNN_ATTR_OPERATION_MATMUL_CDESC,
1095 CUDNN_TYPE_BACKEND_DESCRIPTOR,
1097 &(m_operation.
cmatdesc->get_backend_descriptor()));
1098 if (
status != CUDNN_STATUS_SUCCESS) {
1102 "CUDNN_BACKEND_OPERATION: SetAttribute CUDNN_ATTR_OPERATION_MATMUL_CDESC Failed");
1103 return std::move(m_operation);
1105 status = cudnnBackendSetAttribute(m_operation.
pointer->get_backend_descriptor(),
1106 CUDNN_ATTR_OPERATION_MATMUL_DESC,
1107 CUDNN_TYPE_BACKEND_DESCRIPTOR,
1109 &(m_operation.
matmuldesc->get_backend_descriptor()));
1110 if (
status != CUDNN_STATUS_SUCCESS) {
1114 "CUDNN_BACKEND_OPERATION: SetAttribute CUDNN_ATTR_OPERATION_MATMUL_DESC Failed");
1115 return std::move(m_operation);
1117 }
else if (m_operation.
op_mode == CUDNN_BACKEND_OPERATION_REDUCTION_DESCRIPTOR) {
1119 if ((cudnnGetVersion() / 100) == 81) {
1120 status = cudnnBackendSetAttribute(m_operation.
pointer->get_backend_descriptor(),
1121 CUDNN_ATTR_REDUCTION_OPERATOR,
1122 CUDNN_TYPE_BACKEND_DESCRIPTOR,
1126 status = cudnnBackendSetAttribute(m_operation.
pointer->get_backend_descriptor(),
1127 CUDNN_ATTR_OPERATION_REDUCTION_DESC,
1128 CUDNN_TYPE_BACKEND_DESCRIPTOR,
1132 if (
status != CUDNN_STATUS_SUCCESS) {
1136 "CUDNN_BACKEND_OPERATION: SetAttribute CUDNN_ATTR_OPERATION_REDUCTION_DESC Failed");
1137 return std::move(m_operation);
1139 status = cudnnBackendSetAttribute(m_operation.
pointer->get_backend_descriptor(),
1140 CUDNN_ATTR_OPERATION_REDUCTION_XDESC,
1141 CUDNN_TYPE_BACKEND_DESCRIPTOR,
1143 &(m_operation.
xdesc->get_backend_descriptor()));
1144 if (
status != CUDNN_STATUS_SUCCESS) {
1148 "CUDNN_BACKEND_OPERATION: SetAttribute CUDNN_ATTR_OPERATION_REDUCTION_XDESC Failed");
1149 return std::move(m_operation);
1151 status = cudnnBackendSetAttribute(m_operation.
pointer->get_backend_descriptor(),
1152 CUDNN_ATTR_OPERATION_REDUCTION_YDESC,
1153 CUDNN_TYPE_BACKEND_DESCRIPTOR,
1155 &(m_operation.
ydesc->get_backend_descriptor()));
1156 if (
status != CUDNN_STATUS_SUCCESS) {
1160 "CUDNN_BACKEND_OPERATION: SetAttribute CUDNN_ATTR_OPERATION_REDUCTION_YDESC Failed");
1161 return std::move(m_operation);
1164 status = cudnnBackendFinalize(m_operation.
pointer->get_backend_descriptor());
1165 if (
status != CUDNN_STATUS_SUCCESS) {
1167 return std::move(m_operation);
1169 return std::move(m_operation);
auto setcDesc(ConvDesc_v8 const &conv) -> OperationBuilder_v8 &
cudnnStatus_t initialize_managed_backend_pointer(cudnnBackendDescriptorType_t type)
Initializes the underlying managed descriptor.
static void set_error_and_throw_exception(BackendDescriptor const *desc, cudnnStatus_t status, const char *message)
NLOHMANN_BASIC_JSON_TPL_DECLARATION std::string to_string(const NLOHMANN_BASIC_JSON_TPL &j)
user-defined to_string function for JSON values
auto setAlpha(float alpha) -> OperationBuilder_v8 &
auto setdxDesc(Tensor_v8 const &tensor) -> OperationBuilder_v8 &
auto setwDesc(Tensor_v8 const &tensor) -> OperationBuilder_v8 &
bool is_pointwise_activation_bwd_op
bool is_pointwise_math_op
Operation_v8 & operator=(Operation_v8 const &)=delete
auto setbDesc(Tensor_v8 const &tensor) -> OperationBuilder_v8 &
ManagedOpaqueDescriptor wdesc
ManagedOpaqueDescriptor dxdesc
auto setaMatDesc(Tensor_v8 const &tensor) -> OperationBuilder_v8 &
auto setmatmulDesc(MatMulDesc_v8 const &matmulDesc) -> OperationBuilder_v8 &
cudnnBackendDescriptorType_t op_mode
auto setBeta(float beta) -> OperationBuilder_v8 &
auto setpwDesc(PointWiseDesc_v8 const &pointWiseDesc) -> OperationBuilder_v8 &
auto setAlpha2(float alpha) -> OperationBuilder_v8 &
auto setdwDesc(Tensor_v8 const &tensor) -> OperationBuilder_v8 &
auto setreductionDesc(ReductionDesc_v8 const &reductionDesc) -> OperationBuilder_v8 &
int64_t pointwise_port_count
cudnnPointwiseMode_t pointwise_mode
cudnnStatus_t get_status() const
Current status of the descriptor.
auto setBeta(double beta) -> OperationBuilder_v8 &
auto setbMatDesc(Tensor_v8 const &tensor) -> OperationBuilder_v8 &
ManagedOpaqueDescriptor reductiondesc
Operation_v8(Operation_v8 &&from)
std::shared_ptr< OpaqueBackendPointer > ManagedOpaqueDescriptor
auto setdyDesc(ManagedOpaqueDescriptor const &raw_tensor) -> OperationBuilder_v8 &
std::string describe() const override
Return a string describing the backend Descriptor.
bool is_pointwise_activation_fwd_op
ManagedOpaqueDescriptor dwdesc
ManagedOpaqueDescriptor cmatdesc
auto setdyDesc(Tensor_v8 const &tensor) -> OperationBuilder_v8 &
const char * get_error() const
Diagonistic error message if any.
ManagedOpaqueDescriptor bdesc
cudnnBackendAttributeType_t alphabetaType
ManagedOpaqueDescriptor xdesc
ManagedOpaqueDescriptor pwdesc
ManagedOpaqueDescriptor dydesc
auto setxDesc(ManagedOpaqueDescriptor const &raw_tensor) -> OperationBuilder_v8 &
auto setcMatDesc(Tensor_v8 const &tensor) -> OperationBuilder_v8 &
auto setyDesc(Tensor_v8 const &tensor) -> OperationBuilder_v8 &
ManagedOpaqueDescriptor bmatdesc
auto setxDesc(Tensor_v8 const &tensor) -> OperationBuilder_v8 &
ManagedOpaqueDescriptor cdesc
ManagedOpaqueDescriptor getOutputTensor()
auto setAlpha2(double alpha) -> OperationBuilder_v8 &
ManagedOpaqueDescriptor matmuldesc
OperationBuilder_v8(cudnnBackendDescriptorType_t mode)
ManagedOpaqueDescriptor ydesc
std::string const & getTag() const
auto setAlpha(double alpha) -> OperationBuilder_v8 &
cudnnStatus_t status
Shared pointer of the OpaqueBackendPointer.
ManagedOpaqueDescriptor pointer
ManagedOpaqueDescriptor amatdesc