You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 

46 lines
1.1 KiB

# -*- coding: utf-8 -*-
"""
Created on Wed Feb 3 11:25:23 2021
@author: Liangchao Zhu
"""
from torch import nn
from torch.autograd import Function
import torch
import feconv_cuda
'''
#forward#
INPUT:
* U : [batch_size,18,41,41,41]
* H8types : [batch_size, 1,41,41,41]
* nodIdx : [41,41,41,27,3]
* filters : [2^8, 3*3,27]
OUTPUT:
* KU : [batch_size,18,41,41,41]
'''
class FEconvFunction(Function):
@staticmethod
def forward(ctx,U,H8types,nodIdx,filters):
outputs = feconv_cuda.forward(U,H8types,nodIdx,filters)
KU = outputs[0]
ctx.save_for_backward(*outputs)
return KU
@staticmethod
def backward(ctx,gradU):
outputs = feconv_cuda.backward(gradU,*ctx.saved_variables)
return outputs[0]
# https://blog.csdn.net/littlehaes/article/details/103828130
# ctx.save_for_backward(a, b)能够保存forward()静态方法中的张量,
# 从而可以在backward()静态方法中调用,
# 具体地, 下面地代码通过a, b = ctx.saved_tensors重新得到a和b
class FECONV(nn.Module):
def __init__(self):
super(FECONV,self).__init__()
def forward(self,U,H8types,nodIdx,filters):
return FEconvFunction.apply(U,H8types,nodIdx,filters)