Description
```text
Traceback (most recent call last):
  File "train_wan_video.py", line 268, in <module>
    launcher_map[args.task](accelerator, dataset, model, model_logger, args=args)
  File "diffsynth/diffusion/runner.py", line 34, in launch_training_task
    loss = model(data)
  File "train_wan_video.py", line 130, in forward
    inputs = self.pipe.unit_runner(unit, self.pipe, *inputs)
  File "diffsynth/pipelines/wan_video.py", line 102, in process
    input_latents = pipe.vae.encode(input_video, device=pipe.device, tiled=tiled, ...)
  File "diffsynth/models/wan_video_vae.py", line 512, in encode
    hidden_state = self.single_encode(video, device)
  File "diffsynth/models/wan_video_vae.py", line 545, in single_encode
    # Here i=0, but self._enc_feat_map contains data from the PREVIOUS batch!
    out = self.encoder(x, feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "diffsynth/models/wan_video_vae.py", line 980, in forward
    # Inside Encoder Loop
    x = layer(x, feat_cache, feat_idx)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "diffsynth/models/wan_video_vae.py", line 1150, in forward
    # Inside ResBlock or DownSample Block
    x = layer(x, feat_cache[idx])
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "diffsynth/models/wan_video_vae.py", line 1228, in forward
    # CRASH HAPPENS HERE
    # cache_x (from dirty cache) has 384 channels (Deep Feature)
    # x (current input) has 96 channels (Shallow Feature)
    x = torch.cat([cache_x, x], dim=2)
RuntimeError: sizes of tensors must match except in dimension 2. Expected size 384 but got size 96 for tensor number 1 in the list.
```
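The annotations in the traceback point at the causal feature cache. The toy model below is a hypothetical sketch. It is *not* DiffSynth's `WanVideoVAE`; it only imitates the `feat_cache`/`feat_idx` pattern visible in the traceback, and the names `ToyCausalConv3d`, `ToyEncoder`, `ToyVAE` and `CACHE_T` are made up for illustration. It shows two things. First, `torch.cat([cache_x, x], dim=2)` joins along the time axis, so every other dimension, including channels, must agree. Handing a 96-channel layer the cached frames of a 384-channel layer therefore raises a RuntimeError like the one above. Second, clearing the cache at the start of each `encode` and resetting the slot counter for every chunk keeps repeated calls consistent.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Hypothetical toy, NOT DiffSynth's WanVideoVAE: each causal conv keeps the last
# CACHE_T input frames of the previous chunk in a shared list (`feat_cache`), and a
# one-element counter (`feat_idx`) tells each layer which slot is its own.
CACHE_T = 2


class ToyCausalConv3d(nn.Conv3d):
    def __init__(self, c_in: int, c_out: int):
        # 3 frames in time, 1x1 in space: needs CACHE_T = 2 frames of left context.
        super().__init__(c_in, c_out, kernel_size=(3, 1, 1))

    def forward(self, x, cache_x=None):
        if cache_x is None:
            # First chunk of a video: causal zero padding on the left of the time axis.
            x = F.pad(x, (0, 0, 0, 0, CACHE_T, 0))
        else:
            # Later chunks: prepend cached frames along time (dim 2).
            # All other dims, including channels (dim 1), must match.
            x = torch.cat([cache_x, x], dim=2)
        return super().forward(x)


class ToyEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        # Widths grow with depth, like the 96 -> 384 channels in the traceback.
        self.convs = nn.ModuleList([
            ToyCausalConv3d(3, 96),
            ToyCausalConv3d(96, 384),
            ToyCausalConv3d(384, 384),
        ])

    def forward(self, x, feat_cache, feat_idx):
        for conv in self.convs:
            idx = feat_idx[0]
            # Save this layer's trailing input frames for the next chunk.
            next_cache = x[:, :, -CACHE_T:].detach().clone()
            x = conv(x, feat_cache[idx])
            feat_cache[idx] = next_cache
            feat_idx[0] += 1
        return x


class ToyVAE(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = ToyEncoder()
        self.clear_cache()

    def clear_cache(self):
        # One slot per causal conv; None means "no previous chunk".
        self._enc_feat_map = [None] * len(self.encoder.convs)
        self._enc_conv_idx = [0]

    def encode(self, video, chunk: int = 4):
        # Assumes every chunk has at least CACHE_T frames.
        self.clear_cache()  # never start a new video with a previous call's cache
        try:
            outs = []
            for start in range(0, video.shape[2], chunk):
                self._enc_conv_idx = [0]  # each chunk walks the slots from 0 again
                outs.append(self.encoder(video[:, :, start:start + chunk],
                                         self._enc_feat_map, self._enc_conv_idx))
            return torch.cat(outs, dim=2)
        finally:
            self.clear_cache()  # also runs if an exception interrupts encoding


if __name__ == "__main__":
    torch.manual_seed(0)
    vae = ToyVAE()
    video = torch.randn(1, 3, 8, 4, 4)  # (B, C, T, H, W)
    with torch.no_grad():
        print(vae.encode(video).shape)  # torch.Size([1, 384, 8, 4, 4])
        # Fill the slots by hand, then give the 96-channel layer the cache of a
        # 384-channel layer: this misalignment raises a RuntimeError from torch.cat.
        vae.encoder(video[:, :, :4], vae._enc_feat_map, [0])
        try:
            vae.encoder.convs[1](torch.randn(1, 96, 4, 4, 4), vae._enc_feat_map[2])
        except RuntimeError as err:
            print("RuntimeError:", err)
```

In this toy, a stale cache whose slots are still aligned does not crash; it silently mixes frames from the previous video into the first chunk. The channel-mismatch crash only appears once the slot counter and the cache list are out of step. The `try`/`finally` in `encode` is there so that an interrupted call cannot leave either kind of stale state behind.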
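I got this error while training Wan2.1. I have already checked the input `input_video`: its shape is (1, 3, 121, 480, 832), which should be fine. Going further, I wrote a separate debug script that only loads the dataset and calls the VAE `encode`, and it ran without any problems. Why does this happen, and how can I fix it?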
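Because the standalone script works, one cheap experiment is to force a cache reset right before every VAE call in the training step. If the crash then disappears, state left over from an earlier call is a likely culprit; if it persists, the slots are probably being misaligned within a single call. The helper below is hypothetical. It assumes, as in the upstream Wan2.1 VAE code, that the module holding `_enc_feat_map` exposes a `clear_cache()` method, and that `pipe.vae` is an `nn.Module`. Verify both against your copy of `diffsynth/models/wan_video_vae.py`. It only calls `clear_cache()` where that method actually exists.

```python
import torch.nn as nn


def reset_causal_caches(model: nn.Module) -> int:
    """Call clear_cache() on every submodule that defines it; return how many were reset.

    Hypothetical helper: assumes the cache owner exposes `clear_cache()` (as the
    upstream Wan2.1 VAE does). Modules without that method are skipped.
    """
    count = 0
    for module in model.modules():  # includes `model` itself
        clear = getattr(module, "clear_cache", None)
        if callable(clear):
            clear()
            count += 1
    return count
```

For example, calling `reset_causal_caches(pipe.vae)` immediately before `pipe.vae.encode(...)` would rule out leftovers from previous batches. It does not address causes that corrupt the cache during a single `encode` call.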