You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
46 lines
1.1 KiB
46 lines
1.1 KiB
# -*- coding: utf-8 -*-
|
|
"""
|
|
Created on Wed Feb 3 11:25:23 2021
|
|
|
|
@author: Liangchao Zhu
|
|
"""
|
|
from torch import nn
|
|
from torch.autograd import Function
|
|
import torch
|
|
|
|
import feconv_cuda
|
|
|
|
class FEconvFunction(Function):
    """Autograd wrapper around the ``feconv_cuda`` extension.

    Tensor shapes as noted by the original author for ``forward``:

    INPUT:
        * U        : [batch_size, 18, 41, 41, 41]
        * H8types  : [batch_size, 1, 41, 41, 41]
        * nodIdx   : [41, 41, 41, 27, 3]
        * filters  : [2^8, 3*3, 27]

    OUTPUT:
        * KU       : [batch_size, 18, 41, 41, 41]

    Only ``U`` receives a gradient; ``H8types``, ``nodIdx`` and ``filters``
    are treated as non-differentiable (their gradient slot is ``None``).
    NOTE(review): if gradients w.r.t. ``filters`` are ever required, the
    extension's ``backward`` would have to provide them.
    """

    @staticmethod
    def forward(ctx, U, H8types, nodIdx, filters):
        """Run the extension's forward kernel and return ``KU``.

        ``feconv_cuda.forward`` returns a sequence whose first element is
        ``KU``. The whole sequence is saved on ``ctx`` because
        ``feconv_cuda.backward`` consumes all of it.
        """
        outputs = feconv_cuda.forward(U, H8types, nodIdx, filters)
        KU = outputs[0]
        ctx.save_for_backward(*outputs)
        return KU

    @staticmethod
    def backward(ctx, gradU):
        """Return gradients for the inputs ``(U, H8types, nodIdx, filters)``.

        ``gradU`` is the incoming gradient w.r.t. the output ``KU``.
        Autograd requires exactly one entry per ``forward`` input, so the
        three auxiliary inputs receive ``None``.
        """
        # ``saved_tensors`` replaces the deprecated ``saved_variables``.
        outputs = feconv_cuda.backward(gradU, *ctx.saved_tensors)
        return outputs[0], None, None, None
|
|
# https://blog.csdn.net/littlehaes/article/details/103828130
# ctx.save_for_backward(a, b) stores tensors from the forward() static
# method so that they can be used in the backward() static method;
# specifically, backward() retrieves them with a, b = ctx.saved_tensors.
|
|
|
|
|
|
class FECONV(nn.Module):
    """``nn.Module`` front-end for the ``FEconvFunction`` autograd op.

    The module owns no parameters or buffers; each call is forwarded
    directly to ``FEconvFunction.apply`` with the same four arguments.
    """

    def __init__(self):
        super().__init__()

    def forward(self, U, H8types, nodIdx, filters):
        """Apply the FE convolution op to ``U`` and return its result."""
        fe_conv = FEconvFunction.apply
        return fe_conv(U, H8types, nodIdx, filters)
|