HW10:尝试使用 LoRA 微调 Stable Diffusion 模型(文生图)

   日期:2024-12-26    作者:b1246808 移动:http://ljhr2012.riyuangf.com/mobile/quote/55617.html

目录

0 前言

1 前言

2 开始动手

3 安装必要的库

4 导入

5 准备数据

6 设置项目路径

7 导入数据

7.1 怎么扩充数据集

7.2 怎么让模型理解文本

具体解释

7.3 自定义数据集

8 定义微调相关的函数

8.1 加载 LoRA

8.2 准备优化器

8.3 定义 collate_fn 函数

9 设置相关参数

9.1 设备配置

9.2 模型与训练参数配置

10 微调前的准备

10.1 准备数据集

10.2 准备模型和优化器

11 开始微调

12 生成图像和评估

12.1 什么是 pipeline

12.2 推理相关的参数

12.3 加载用于验证的 prompts

12.4 定义生成图像的函数

12.5 定义评估函数

13 拓展作业

14 用脚本微调 SD(可选

14.1 克隆仓库

14.2 执行脚本

15 参考链接


本文为李宏毅学习笔记——2024春《GENERATIVE AI》篇——作业笔记HW10。

如果你还没获取到LLM API,请查看我的另一篇笔记

HW1~2:LLM API获取步骤及LLM API使用演示:环境配置与多轮对话演示-CSDN博客

完整内容参见

李宏毅学习笔记——2024春《GENERATIVE AI》篇

总得拆开炼丹炉看看是什么样的。这篇文章将带你从代码层面一步步实现 AI 文本生成图像(Text-to-Image)中的 LoRA 微调过程,你将

  • 了解 Trigger Words(触发词)到底是什么,以及它们如何影响生成结果。
  • 掌握 LoRA 微调的基本原理。
  • 学习数据集的准备与结构,并知道如何根据需求定制自己的数据集。
  • 理解 Stable Diffusion 模型的微调步骤。
  • 明白在画图界面(UI)下到底发生了什么。
  • 使用代码实现 AI 绘画。

如果你想制作属于自己的数据集,最好遵循以下建议

  1. 至少准备 20 张图片:想学到的概念越复杂就需要越多的图片。你可以尝试将样例数据集的图片数量减少到 20 张,看看效果会有什么变化。
  2. 裁剪图片:建议对图片进行裁剪,当然你也可以不裁剪,如果你不追求效果的话。这里会自动 resize 到自定义的分辨率。

与其花费大量时间去调参,更优的选择是处理好你的数据集和 Prompts。当然,这两件事情可以同步进行。

注意,当前文章使用的是自然语言标注(而非 Tag,你也可以使用 Tag,这两种方式本质上是一致的。

同时,如果你对深度学习有所了解,那么代码中的一切,都将是你曾经见过的内容翻版,没有什么新的,除了 LoRA。我们将同步使用演员 Brad Pitt(布拉德·皮特)的图片作为训练集,共计一百张。

下面是使用 prompt,在默认设置下训练 2000 个步骤后模型生成的图像,训练时长约为 18 分钟。乍一看,是不是还挺不错的

你可能会注意到,我们的 prompt 中并没有提到 Brad Pitt(布拉德·皮特)这个演员(尽管我们的数据集完全来自于他,但模型却能够绘制长得像 Brad Pitt 的人。

这是因为,如果我们在 prompt 中直接指定 "Brad Pitt",模型可能无法完全学习到他的特征风格。举个例子

  • "A man in a graphic tee and sport coat. Brad Pitt."
  • "A man in a graphic tee and sport coat."

第一条 prompt 显然更精准,但精准并不意味着模型训练得更好。如果你用一系列包含 "Brad Pitt" 的 prompt 来训练,模型更有可能学到的是:只有在加上 "Brad Pitt" 时才进行风格转变。你可能会说:“我就是想要这个效果”,那么很好"Brad Pitt" 就是你模型的 Trigger Word(触发词)。但有可能还有同学:“我希望模型只为 Brad Pitt 服务,我要把所有的 'man' 都变成 Brad Pitt”,那么在训练时就不要在 prompt 中增加 "Brad Pitt"。简而言之反着来

这实际上并没有反直觉,跳出来想一想

  1. 想象一下你是一位画家,生活在一个从不变暗的世界里,整个世界永远是白天,你已经习惯画出白天背景下的各种景象,但你不知道白天是什么,这就是你所熟知的「日常」。

  2. 有一天,有人给你看了一些照片,说:“Hey,实际上世界可以是黑的,叫做夜晚”,这时候你就会理解到,日常是有另一种状态的,叫做夜晚,即便你以前从来没有过概念,但现在,你将认知到它,你将这部分新的概念聚焦到了「夜晚」。于是,从此以后,你的画作被分为了「日常」和「日常,夜晚」。

  3. 同时,在另一个平行世界,有人告诉你:“你眼中看到的世界是不对的”,他们“治”好了你的眼睛,向你展示了一个完全陌生的漆黑世界,并承诺只要你学会画出这种风格的画作,将会获得丰厚的回报,否则将无人问津你的画摊。于是你开始画“夜晚”风格的「日常」。

这是杜攥的三个小片段,希望你喜欢。

你可以分别将它理解为

  1. 原始模型:活在自己世界的画家。
  2. LoRA 微调:当新标签(Tag)“夜晚”被引入,画家学会了夜晚的概念。Prompt:夜晚,日常。
  3. 另一个 LoRA 微调:迁移风格,画家将“夜晚”视为真正的日常风格。Prompt:日常。

因此,训练模型就像教小朋友认知世界。如果你将世界分解为不同的概念并逐一传授,孩子会学到不同的知识。这就类似于模型学习不同的标签和风格。如果你不明确区分概念,并将新概念混杂在已有的认知中,孩子的认知会被重塑,或许会将鹿“误”认为马。这是合理的,模型也是如此,取决于你如何教导(prompt)它。

Prompt 小技巧

  • 明确你的目标:在训练前,思考你是希望模型学习特定的风格、特定的人物,还是希望模型在特定的场景下才生成特定的效果。到底是希望所有的 man 都是 Brad Pitt,还是希望模型知道 Brad Pitt 是一个 man。
  • 保持一致性:如果你希望将某个概念拆分出来,应该为它创建一个特定的标签(tag,并应用于具有相同概念的图像上。

大模型很聪明,它会自动将图像中的共性归因于共用的标签上。因此,如果不给它新的标签,它会将新学到的内容融入到已有的标签中。

这些是关于 AI 绘画 Prompt + 微调背后逻辑的大白话。扯远了,让我们回到代码部分 :)

下面,我将带你从代码层面一步步实现 LoRA 微调 Stable Diffusion 模型。注意,这里的知识是通用的,你完全可以推广至任何需要 LoRA 微调的领域。

首先,确保安装以下必要的 Python 库

 
 
 
 

当前演示使用的是 Brad Pitt(布拉德·皮特,我们的目标是让模型绘制的 man 是 Brad Pitt,粗略地换个表述:AI 换脸。

那根据我们之前的描述,标注应该长什么样呢

:都带 “man”,下面是我们当前数据集的标注示例

  1. a man with a beard and a suit jacket
  2. a man in a suit and tie standing in front of a crowd
  3. a man with long hair and a tie
  4. ...

相信你发现了,所有的标注,都不会含有 “Brad Pitt”,那这篇文章训练出的 LoRA 模型的 Trigger Words(触发词)是什么

:“a man”。

是不是很有趣,看似简单的 Prompt 中也有一些真实有用的小技巧和逻辑。别急着去炼丹,我们继续往下看。

在这里,我们使用 Brad Pitt 的 100 张图片进行演示,数据集已经上传到了Demos/data/14,你可以下载后放到当前目录下的  下。这个路径没有什么说法,单纯是为了对齐示例代码,你也可以修改代码关于数据的路径,这里不会有限制,你甚至可以直接用其他的数据集,只要它的文件组织如下

-- 图片1
-- 图片1.txt
-- 图片2
-- 图片2.txt
...

注意:图片和对应的文本标注需要同名,且位于同一文件夹中。

值得一提的是,样例数据集的裁剪大小和比例都是不一致的,只是接近正方形,但这没有太大的关系,因为在数据预处理的时候会自动放缩(resize,所以在这里不用担心你的数据集无法训练。

很好!现在你已经知道这篇文章数据集相关的所有前置知识,直接复制下面的代码运行,不用在意其中的任何代码细节,你只需要知道会创建一个文件夹,之后的所有结果都会被存放在其中

 
 

下面,我们需要自定义一个  类,它的作用是告诉模型如何处理你的数据集,这个自定义的类能够返回图像和文本标注分别作为  和 。接下来的内容会有点“干”,你也可以将其先当作黑盒,我会在每个函数之后提供一个简练的解释帮你理解。

拓展文章:e. 数据增强:torchvision.transforms 常用方法解析

这里有一个非常熟悉的词,但这个跟我们耳熟能详的  可不同, 就是单纯的对图像进行操作,比如说调整大小,翻转,又或者随机的裁剪一部分区域,这些操作统称为数据增强。

数据增强就是扩充数据集的外挂,以下图为例,即便进行水平翻转+颜色变化+中心裁剪,它也是一只企鹅。

这大大地扩充了数据集。

知道了概念后,简单定义当前的数据增强如下

 
 

使用 CLIPTokenizer,这是 Hugging Face transformers 库中的一个类,专门用于对文本进行分词(tokenization)操作。CLIP,全称 Contrastive Language-Image Pretraining(对比语言-图像预训练,Contrastive 这个词说透了它的由来,这是一个非常有意思的自训练思想:通过最大化对应文本-图像对的相似性,同时最小化不同文本-图像对的相似性实现训练。

学习资料

论文链接:Learning Transferable Visual Models From Natural Language Supervision 对理论感兴趣的话可以进一步查看以下四个非常棒的视频

  1. 对比学习论文综述【论文精读】
  2. CLIP 论文逐段精读【论文精读】
  3. CLIP 改进工作串讲(上)【论文精读·42】
  4. CLIP 改进工作串讲(下)【论文精读·42】

你将发现两个宝藏 UP 主,我无法用语言表达对他们的赞美,只能道一句:“导师好!”。

具体来说, 将输入的 prompt 拆解为 token(单词或子词,并将这些 token 映射为 供 CLIP 模型的  处理,从而生成 prompt 的嵌入向量,以让模型理解。

就像一切数据到了计算机中都变成 0,1 让其处理,所以向上抽象一下 就是将人类可以阅读的文本描述变成模型能够理解的形式。

拓展:看看 Tokenizer 实际上做了什么

 
 

输出

 
 

:49407 是什么?我们的 prompt 中似乎没有重复的词。

:结束标记,这是因为我们设置了 。思考一下,设置后输出应该是什么样的?先不要往下滑。

具体解释
  • Tokenized Input IDs:这个张量展示了输入文本  被转换为的数字 ID 序列。每个数字 ID 对应于词汇表中的一个 token, 是起始标记, 是结束标记。
  • Attention Mask:用于标记哪些 token 需要模型的关注,1 表示有效 token,0 表示填充的无效 token。

时的输出

 
 

是不是和预期一致呢

接下来, 将被传入 ,生成文本的嵌入向量。

在认识  和  之后,我们可以定义自己的数据集。这个  负责将图像和文本配对,并进行数据的预处理,以便输入到模型中。

 

解释

  • :定义可接受的图像文件扩展名列表。
  •  方法
    • 图像路径:通过遍历指定的图像文件夹,获取所有符合扩展名的图像文件路径,并排序。
    • 文本标注:在标注文件夹中查找所有  文件,读取其内容并存储为列表。
    • 一致性检查:确保图像数量与文本标注数量一致。
    • 文本编码:使用  将文本标注转换为 token IDs。
    • 数据转换:存储图像的预处理方法 。
  •  方法
    • 根据索引获取图像路径和对应的文本 token ID。
    • 尝试加载并预处理图像,失败时返回全零张量。
  •  方法:返回数据集的长度。

前置文章

  • HW4-补充1:认识 LoRA:从线性层到注意力机制-CSDN博客
  • HW8-补充1:在大模型中快速应用 LoRA-CSDN博客

LoRA(Low-Rank Adaptation 是一种非常高效的参数微调方法,通过在预训练模型的特定层添加小的低秩矩阵(可以联想线性代数中的奇异值分解,来实现模型的微调,这也是一类 Adapter。

LoRA 的核心思想是将大模型中的某些权重矩阵近似为两个低秩矩阵进行更新,从而大幅减少需要微调的参数数量,提高训练效率和节省存储空间。一般而言,模型越大,减小比例越夸张,对于 GPT-3,LoRA 微调的训练参数量为原来的 1/10000。

通常,在微调时我们只对模型的特定部分(如注意力机制中的 Q、K、V 矩阵)进行 LoRA 微调,而不是微调整个模型。这里选择对  和  增加 LoRA,因为这两个模块直接负责图像生成和文本引导中的关键任务: 处理扩散过程的逆运算, 将输入文本转换为特征向量。下面,我们定义一个函数来应用 LoRA 模型。

 

解释

  • 加载模型组件 依次加载了噪声调度器、Tokenizer、文本编码器)、VAE 和 UNet 模型。
  • 应用 LoRA 使用  函数将 LoRA 配置应用到  和  模型中。这会在模型中插入可训练的 LoRA 层。
  • 打印可训练参数 调用  来查看 LoRA 添加了多少可训练参数。
  • 恢复训练 如果设置了 ,则从指定的  加载之前保存的模型权重。
  • 合并 LoRA 权重 如果 ,则将 LoRA 的权重合并到基础模型中,以便在推理时使用。
  • 冻结 VAE 参数 调用  来冻结 VAE 的参数,使其在训练中不更新。
  • 移动模型到设备 将所有模型组件移动到指定的设备(CPU 或 GPU,并设置数据类型。

为什么只微调  和  最终却返回这么多模块

因为在后面的微调中,我们将从文本开始处理而非将其当作又一个黑盒。

接下来,需要对于应用了 LoRA 的 UNet 和文本编码器)分别使用不同的学习率,这也是炼丹炉 UI 中常需要调节的选项。

 
 

在大多数常见的机器学习任务中(例如图像分类或回归,数据集通常是简单的  结构,PyTorch 的  默认能够处理这样的简单数据结构,将样本打包成批次(batch)。在我们的项目中,每个样本也是一个包含图像张量和文本编码的元组 。默认的  可以将这些样本打包成批次,访问时需要使用索引,例如  和 。

为了使代码更具可读性,我们可以自定义一个  函数,将批次数据组织成字典的形式,方便通过键名直接访问,例如  和 。自定义的  定义如下

 

解释

  •  是什么
    •  是一个列表,包含了一个批次中的多个样本。
    • 其中的每个样本都是从我们自定义的  数据集中获取的,形式为 。
      • :经过预处理的图像张量,形状为 ,即通道数(Channel)和图像的高度(Height)、宽度(Weight)。
      • :对应的文本标注经过  编码后的张量,形状为 。

补充:PyTorch 的  函数会将多个张量沿新维度拼接在一起。例如,将一批图像张量拼接成  的形式,确保每个批次数据的组织结构一致。

拓展:自定义和默认  的对比

下面提供了一个对比函数,来展示自定义  和默认  在处理当前数据时的不同。你可以通过运行代码来观察自定义和默认方式的使用差异。

 
 

输出

 
 

具体选择哪一种由你决定,默认的方法实际上更普遍。

当前的微调毫无疑问需要用到显卡(GPU,对于 Apple 芯片的 Mac 来说,把 "cuda" 改为 "mps",也就是使用第二行代码,但需要注意的是,对于PyTorch版本过低的环境,  会报错,所以这里选择注释。

 
 

这里的参数大多与之前的函数相关,下面是你可以调节的内容

  • 训练参数:设置批次大小、数据类型、随机种子等。
    •  时,微调显存要求为 5G,在命令行输入  可以查看当前显存占用。
  • 优化器参数:为 UNet 和文本编码器分别设置学习率。
  • 学习率调度器:选择  调度器,这一点一般无关紧要。
  • 预训练模型:指定预训练的 Stable Diffusion 模型。
  • LoRA 配置:设置 LoRA 的相关参数,如秩 、、应用模块等。
 
 
 
 

解释

  • 加载 Tokenizer 使用与预训练模型相同的 Tokenizer。
  • 创建数据集 使用我们之前定义的 。
  • 创建数据加载器 使用 PyTorch 的 。
 

解释

  • 准备模型 调用之前定义的  函数。
  • 准备优化器 调用之前定义的  函数。
  • 设置学习率调度器 使用 Hugging Face 的  函数。

主要流程和结构如下

  • 训练循环 我们在多个  中进行训练,直到达到 。每个  代表一轮数据的完整训练,在常见的 UI 界面中也可以看到  和  的参数。
  • 编码图像 使用 VAE(变分自编码器)将图像编码为潜在表示(latent space,以便后续在扩散模型中添加噪声并进行处理。
  • 添加噪声 使用噪声调度器)为潜在表示添加随机噪声,模拟图像从清晰到噪声的退化过程。这是扩散模型的关键步骤,训练时模型通过学习如何还原噪声,从而在推理过程中通过逐步去噪生成清晰的图像。
  • 获取文本嵌入 使用文本编码器)将输入的文本 prompt 转换为隐藏状态(我们见过很多类似的表达:隐藏向量/特征向量/embedding/...,为图像生成提供文本引导信息。
  • 计算目标值 根据扩散模型的类型( 或 ,确定模型的目标输出(噪声或速度向量)。
  • UNet 预测 使用 UNet 模型对带噪声的潜在表示进行预测,生成的输出用于还原噪声或预测速度向量。
  • 计算损失 通过加权均方误差(MSE)计算模型损失,并进行反向传播。
  • 优化与保存:通过优化器更新模型参数,并在适当时保存检查点。
 

训练完成后的  会保存到  中,以  为例,模型输出如下

 是 Hugging Face 库中一种高层次的封装工具,通常用于推理。默认情况下, 以 eval 模式加载模型,因此适合用于生成或评估场景。我们这里使用的是 ,它将之前提到的多个模型组件(如 UNet、VAE、文本编码器等)组合在一起,实现从文本到图像的生成。

 的工作原理也跟之前微调过程类似

  1. 文本编码: 中的文本编码器会将输入的  转换为特征向量。
  2. 噪声注入:在潜在空间中,模型从随机噪声开始生成图像。
  3. 迭代去噪:UNet 使用从文本编码器得到的特征向量指导去噪过程,逐步将噪声还原为高质量图像。
  4. 图像解码:最终,VAE 将潜在表示解码为实际的图像。
  1. 什么是推理步数

    推理步数控制扩散模型生成图像时的去噪迭代次数。步数越多,生成的图像质量越高,但推理时间也相应增加。这是一个需要你根据图像质量和时间需求去权衡的参数,通常在肉眼觉得够好的时候,就可以了。

  2. 如何决定  的影响程度

     决定了文本提示对生成图像的影响程度。较高的  会让模型更严格地按照  生成图像,数值通常在 7.5 到 10 之间调整,过高可能会导致图像失真,同样需要你去权衡。这个参数与文本生成任务中的  参数类似,适用于不同场景。

  3. 怎么确保相同  生成相同的图像

    设置固定的随机数种子(seed,可以确保同样的  在每次运行时生成相同的图像。可以通过使用  生成随机数并设置种子(seed,示例如下

    generator = torch.Generator().manual_seed(42)

这是一组用于生图的文本提示(prompts,本实验中位于,下面摘取几行 prompt 预览

  • A man in a black hoodie and khaki pants.
  • A man sports a red polo and denim jacket.
  • A man wears a blue shirt and brown blazer.
  • ...

定义加载  的函数如下

 
 

结合之前的讨论,我们可以定义一个生成图像的函数

 
 

虽然图像生成的好与坏现在更多的由人去判断,但最基础的模块还是可以交给机器,以当前实验为例,我们的目的是 “AI 换脸”,那就可以有两个新的度量

  • 无脸图像的数量

    使用  库检测生成图像中的人脸。如果没有检测到人脸,则该图像计为无脸图像,数量加 1。
  • 面部相似性

    利用  库提取生成图像中的人脸特征,然后与训练集中人脸的特征进行对比。通过计算欧氏距离来衡量相似度,距离越小,表示生成的人脸与训练集中人脸的相似度越高。

除了人脸生成之外,AI 图像生成领域还有很多其他应用场景。那么,有没有通用的评估方法来衡量生成图像与文本提示的匹配度呢

CLIP 评分

是的,CLIP 除了可以处理文本输入,还可以评估最终的模型,无论生成的是人脸、风景还是物体,它都可以帮助我们判断生成图像与文本提示的相关性。

对于当前实验,我们采取这三种方式对模型进行度量,完整流程如下

  1. 使用  函数从文件中加载 prompts。
  2. 使用  函数加载已经经过 LoRA 微调的 UNet 和文本编码器,并合并 LoRA 权重。模型会从上一次训练保存的文件中恢复权重。
  3. 使用已经微调的 UNet 和文本编码器来创建 。
  4. 加载 CLIP 模型后续用于评估。
  5. 使用  提取训练图像的面部嵌入  与生成的图像进行对比,计算面部相似度。
  6. 进行评估,最后打印结果。
 
 

 

生成的图像会保存在  中。 

  1. 当前 prompt 的触发词trigger words)只是 “a man” 吗
    仔细观察之前数据集的prompt

    • a man with a beard and a suit jacket
    • a man in a suit and tie standing in front of a crowd
    • a man with long hair and a tie
    • ...
  2. 使用当前数据集训练出的模型,如果 prompt 设置为 “a man”,生成的图像应该是什么样的

  3. 除了之前设置的参数外,探究生成图像相关参数(位于 )。

 

希望你能通过对代码文件的运行,找到它们的答案。 

这是可选的行为,脚本的代码处理逻辑与文章对应。

 
 
  1. 切换到  文件夹

     
  2. 准备样例数据集

    # 如果已经下载过,可以跳过,将之后的命令参数修改为对应路径
    wget https://github.com/Hoper-J/AI-Guide-and-Demos-zh_CN/raw/refs/heads/master/Demos/data/14/Datasets.zip
    unzip Datasets.zip
    
  3. 使用指定的数据集和提示文件

    python sd_lora.py -d https://blog.csdn.net/a131529/article/details/Datasets/Brad -gp https://blog.csdn.net/a131529/article/details/Datasets/prompts/validation_prompt.txt
    •  或 :数据集路径。
    •  或 :生成图像时使用的文本提示文件路径。
  4. 指定其他参数

    python sd_lora.py -d https://blog.csdn.net/a131529/article/details/Datasets/Brad -gp https://blog.csdn.net/a131529/article/details/Datasets/prompts/validation_prompt.txt -e 500 -b 4 -u 1e-4 -t 1e-5
    •  或 :总训练步数。
    •  或 :训练批次大小。
    •  或 :UNet 的学习率。
    •  或 :文本编码器的学习率。
    • 其他参数使用  进行查看。

特别提示:本信息由相关用户自行提供,真实性未证实,仅供参考。请谨慎采用,风险自负。


举报收藏 0评论 0
0相关评论
相关最新动态
推荐最新动态
点击排行
{
网站首页  |  关于我们  |  联系方式  |  使用协议  |  隐私政策  |  版权隐私  |  网站地图  |  排名推广  |  广告服务  |  积分换礼  |  网站留言  |  RSS订阅  |  违规举报  |  鄂ICP备2020018471号