2.2 将数据输入模型
书接上回,我们将二进制文件转换为了 .png 图片以及 label.txt, 现在我们要把这些图片输入到模型中。
因此我们构造了一个 class customDataset 来导入数据(其实就是放到 label[] 和 img[] 里)。
(资料图片仅供参考)
首先是初始化,在初始化的部分会导入图片的路径和标签
其次是将图片通过 CV2 导入
最后是后面的函数需要的 len 函数
3. 构建模型
3.1. 生成器 generator
此处使用的是 Multi-LayerPerception 的全连接层来链接不同的层。
由于图片是 28×28 的格式,所以最后应该是一个1×784 的层。
由此可以得出整个模型的形式 : latent dim -> 1024 -> 784
而对于每一层,形式是:nn.linear → normalize → 激活函数
(此处为Leaky ReLU)
于是可以写出初始化的代码:
之后就是 forward 函数:
3.2. 鉴别器 discriminator
鉴别器和生成器结构非常相似,只是反过来而已。
至此,我们已经完成了GAN的大部分内容,包括数据输入、生成器和鉴别器,下一步就是训练模型了。
标签: