
VAE encode error during Wan training #1113

@Pensioner11

Description


Traceback (most recent call last):
  File "train_wan_video.py", line 268, in
    launcher_map[args.task](accelerator, dataset, model, model_logger, args=args)
  File "diffsynth/diffusion/runner.py", line 34, in launch_training_task
    loss = model(data)
  File "train_wan_video.py", line 130, in forward
    inputs = self.pipe.unit_runner(unit, self.pipe, *inputs)
  File "diffsynth/pipelines/wan_video.py", line 102, in process
    input_latents = pipe.vae.encode(input_video, device=pipe.device, tiled=tiled, ...)
  File "diffsynth/models/wan_video_vae.py", line 512, in encode
    hidden_state = self.single_encode(video, device)
  File "diffsynth/models/wan_video_vae.py", line 545, in single_encode
    # Here i=0, but self._enc_feat_map contains data from the PREVIOUS batch!
    out = self.encoder(x, feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "diffsynth/models/wan_video_vae.py", line 980, in forward
    # Inside Encoder Loop
    x = layer(x, feat_cache, feat_idx)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "diffsynth/models/wan_video_vae.py", line 1150, in forward
    # Inside ResBlock or DownSample Block
    x = layer(x, feat_cache[idx])
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "diffsynth/models/wan_video_vae.py", line 1228, in forward
    # CRASH HAPPENS HERE
    # cache_x (from dirty cache) has 384 channels (Deep Feature)
    # x (current input) has 96 channels (Shallow Feature)
    x = torch.cat([cache_x, x], dim=2)
RuntimeError: sizes of tensors must match except in dimension 2. Expected size 384 but got size 96 for tensor number 1 in the list.

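I hit this error while training Wan2.1. I have already checked the input: input_video has shape (1, 3, 121, 480, 832), so it should be fine. I also wrote a standalone debug script that only loads the dataset and calls the VAE encode, and it ran through without any problem. Why does this happen, and how can I fix it?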
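
To make the stale-cache hypothesis from the traceback comments concrete, here is a minimal toy sketch in pure Python. It is not DiffSynth code: `ToyCachedEncoder`, `reset_state`, `encode_chunk`, `run`, and the channel widths are all hypothetical, and whether the real VAE gets into this state is an assumption. It only shows how a per-layer feature cache plus a slot counter that is left over from an interrupted call could produce a "384 vs 96" mismatch in a long-running training process, while a fresh standalone script (clean state) passes.

```python
class ToyCachedEncoder:
    """Toy stand-in for an encoder with a per-layer feature cache (not DiffSynth code)."""

    CHANNELS = (96, 192, 384)  # hypothetical layer widths, shallow to deep

    def __init__(self):
        self.reset_state()

    def reset_state(self):
        # One cached channel count per layer, plus the slot the next layer will use.
        self.feat_cache = [None] * len(self.CHANNELS)
        self.conv_idx = 0

    def encode_chunk(self, fail_at_layer=None):
        for layer, channels in enumerate(self.CHANNELS):
            if layer == fail_at_layer:
                raise MemoryError("simulated interruption mid-encode")
            cached = self.feat_cache[self.conv_idx]
            if cached is not None and cached != channels:
                # Mirrors the channel mismatch that torch.cat reports in the traceback.
                raise RuntimeError(f"Expected size {cached} but got size {channels}")
            self.feat_cache[self.conv_idx] = channels
            self.conv_idx += 1
        self.conv_idx = 0  # only reached when the chunk finishes cleanly


def run(reset_between_calls):
    enc = ToyCachedEncoder()
    enc.encode_chunk()                     # first chunk fills the cache: [96, 192, 384]
    try:
        enc.encode_chunk(fail_at_layer=2)  # interrupted; conv_idx is left at 2
    except MemoryError:
        pass
    if reset_between_calls:
        enc.reset_state()
    try:
        enc.encode_chunk()                 # next sample starts again at layer 0
        return "ok"
    except RuntimeError as err:
        return str(err)


print(run(reset_between_calls=True))   # ok
print(run(reset_between_calls=False))  # Expected size 384 but got size 96
```

If the real encoder behaves similarly, resetting the cache and index state named in the traceback (`self._enc_feat_map`, `self._enc_conv_idx`) before every encode call, using whatever reset mechanism the VAE class actually provides, would be one thing to try.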
