File size: 457 Bytes
39b2670
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin

class CCProjection(ModelMixin, ConfigMixin):
    """Single learnable linear projection from ``in_channel`` to ``out_channel`` features.

    The module wraps one ``torch.nn.Linear`` layer (with bias, the PyTorch
    default) so that it can be saved and loaded through the diffusers
    ``ModelMixin`` / ``ConfigMixin`` machinery.

    Args:
        in_channel: Size of the last dimension of the input tensor.
        out_channel: Size of the last dimension of the output tensor.
    """

    # Registering the init arguments stores them in the model config, so a
    # checkpoint created with non-default sizes is rebuilt with the same
    # layer shape on reload instead of silently falling back to the defaults.
    @register_to_config
    def __init__(self, in_channel: int = 1028, out_channel: int = 1024) -> None:
        super().__init__()
        # Kept as plain attributes as well, for callers that read them directly.
        self.in_channel = in_channel
        self.out_channel = out_channel
        self.projection = torch.nn.Linear(in_channel, out_channel)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the linear projection to ``x``.

        Args:
            x: Tensor whose last dimension has size ``in_channel``; any
                leading dimensions are passed through unchanged.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of size ``out_channel``.
        """
        return self.projection(x)