Commit 6a6ae6d

Merge pull request #1312 from mi804/ltx2-iclora
Ltx2 iclora
2 parents: 8fc7e00 + 1a380a6

File tree

14 files changed: +516 additions, −24 deletions
README.md

Lines changed: 2 additions & 0 deletions
```diff
@@ -645,6 +645,8 @@ Example code for LTX-2 is available at: [/examples/ltx2/](/examples/ltx2/)
 | Model ID | Extra Args | Inference | Low-VRAM Inference | Full Training | Full Training Validation | LoRA Training | LoRA Training Validation |
 |-|-|-|-|-|-|-|-|
 |[Lightricks/LTX-2: OneStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-OneStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-OneStage.py)|[code](/examples/ltx2/model_training/full/LTX-2-T2AV-splited.sh)|[code](/examples/ltx2/model_training/validate_full/LTX-2-T2AV.py)|[code](/examples/ltx2/model_training/lora/LTX-2-T2AV-splited.sh)|[code](/examples/ltx2/model_training/validate_lora/LTX-2-T2AV.py)|
+|[Lightricks/LTX-2-19b-IC-LoRA-Union-Control](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-IC-LoRA-Union-Control)|`in_context_videos`,`in_context_downsample_factor`|[code](/examples/ltx2/model_inference/LTX-2-T2AV-IC-LoRA-Union-Control.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-IC-LoRA-Union-Control.py)|-|-|[code](/examples/ltx2/model_training/lora/LTX-2-T2AV-IC-LoRA-splited.sh)|[code](/examples/ltx2/model_training/validate_lora/LTX-2-T2AV-IC-LoRA.py)|
+|[Lightricks/LTX-2-19b-IC-LoRA-Detailer](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-IC-LoRA-Detailer)|`in_context_videos`,`in_context_downsample_factor`|[code](/examples/ltx2/model_inference/LTX-2-T2AV-IC-LoRA-Detailer.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-IC-LoRA-Detailer.py)|-|-|[code](/examples/ltx2/model_training/lora/LTX-2-T2AV-IC-LoRA-splited.sh)|[code](/examples/ltx2/model_training/validate_lora/LTX-2-T2AV-IC-LoRA.py)|
 |[Lightricks/LTX-2: TwoStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-TwoStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-TwoStage.py)|-|-|-|-|
 |[Lightricks/LTX-2: DistilledPipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-DistilledPipeline.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-DistilledPipeline.py)|-|-|-|-|
 |[Lightricks/LTX-2: OneStagePipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)|`input_images`|[code](/examples/ltx2/model_inference/LTX-2-I2AV-OneStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-OneStage.py)|-|-|-|-|
```

README_zh.md

Lines changed: 2 additions & 0 deletions
```diff
@@ -645,6 +645,8 @@ LTX-2 的示例代码位于:[/examples/ltx2/](/examples/ltx2/)
 |模型 ID|额外参数|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
 |-|-|-|-|-|-|-|-|
 |[Lightricks/LTX-2: OneStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-OneStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-OneStage.py)|[code](/examples/ltx2/model_training/full/LTX-2-T2AV-splited.sh)|[code](/examples/ltx2/model_training/validate_full/LTX-2-T2AV.py)|[code](/examples/ltx2/model_training/lora/LTX-2-T2AV-splited.sh)|[code](/examples/ltx2/model_training/validate_lora/LTX-2-T2AV.py)|
+|[Lightricks/LTX-2-19b-IC-LoRA-Union-Control](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-IC-LoRA-Union-Control)|`in_context_videos`,`in_context_downsample_factor`|[code](/examples/ltx2/model_inference/LTX-2-T2AV-IC-LoRA-Union-Control.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-IC-LoRA-Union-Control.py)|-|-|[code](/examples/ltx2/model_training/lora/LTX-2-T2AV-IC-LoRA-splited.sh)|[code](/examples/ltx2/model_training/validate_lora/LTX-2-T2AV-IC-LoRA.py)|
+|[Lightricks/LTX-2-19b-IC-LoRA-Detailer](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-IC-LoRA-Detailer)|`in_context_videos`,`in_context_downsample_factor`|[code](/examples/ltx2/model_inference/LTX-2-T2AV-IC-LoRA-Detailer.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-IC-LoRA-Detailer.py)|-|-|[code](/examples/ltx2/model_training/lora/LTX-2-T2AV-IC-LoRA-splited.sh)|[code](/examples/ltx2/model_training/validate_lora/LTX-2-T2AV-IC-LoRA.py)|
 |[Lightricks/LTX-2: TwoStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-TwoStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-TwoStage.py)|-|-|-|-|
 |[Lightricks/LTX-2: DistilledPipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-DistilledPipeline.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-DistilledPipeline.py)|-|-|-|-|
 |[Lightricks/LTX-2: OneStagePipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)|`input_images`|[code](/examples/ltx2/model_inference/LTX-2-I2AV-OneStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-OneStage.py)|-|-|-|-|
```

diffsynth/diffusion/base_pipeline.py

Lines changed: 7 additions & 4 deletions
```diff
@@ -94,20 +94,23 @@ def to(self, *args, **kwargs):
         return self


-    def check_resize_height_width(self, height, width, num_frames=None):
+    def check_resize_height_width(self, height, width, num_frames=None, verbose=1):
         # Shape check
         if height % self.height_division_factor != 0:
             height = (height + self.height_division_factor - 1) // self.height_division_factor * self.height_division_factor
-            print(f"height % {self.height_division_factor} != 0. We round it up to {height}.")
+            if verbose > 0:
+                print(f"height % {self.height_division_factor} != 0. We round it up to {height}.")
         if width % self.width_division_factor != 0:
             width = (width + self.width_division_factor - 1) // self.width_division_factor * self.width_division_factor
-            print(f"width % {self.width_division_factor} != 0. We round it up to {width}.")
+            if verbose > 0:
+                print(f"width % {self.width_division_factor} != 0. We round it up to {width}.")
         if num_frames is None:
             return height, width
         else:
             if num_frames % self.time_division_factor != self.time_division_remainder:
                 num_frames = (num_frames + self.time_division_factor - 1) // self.time_division_factor * self.time_division_factor + self.time_division_remainder
-                print(f"num_frames % {self.time_division_factor} != {self.time_division_remainder}. We round it up to {num_frames}.")
+                if verbose > 0:
+                    print(f"num_frames % {self.time_division_factor} != {self.time_division_remainder}. We round it up to {num_frames}.")
             return height, width, num_frames
```
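The rounding that `check_resize_height_width` applies can be sketched in isolation. The helpers below are a standalone toy (the factor values are assumed for illustration; the real division factors come from the pipeline instance):

```python
def round_up(value: int, factor: int) -> int:
    # Same pattern as the diff: round value up to the nearest multiple of factor.
    return (value + factor - 1) // factor * factor

def round_up_frames(num_frames: int, factor: int, remainder: int) -> int:
    # Frame counts must satisfy num_frames % factor == remainder,
    # mirroring the time_division_factor / time_division_remainder check.
    if num_frames % factor != remainder:
        num_frames = round_up(num_frames, factor) + remainder
    return num_frames

print(round_up(500, 32))           # -> 512
print(round_up_frames(120, 8, 1))  # -> 121
```

With a time division factor of 8 and remainder 1, 120 frames rounds to 121, which matches the pipeline's `num_frames=121` default.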

diffsynth/pipelines/ltx2_audio_video.py

Lines changed: 84 additions & 10 deletions
```diff
@@ -61,6 +61,7 @@ def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
             LTX2AudioVideoUnit_InputAudioEmbedder(),
             LTX2AudioVideoUnit_InputVideoEmbedder(),
             LTX2AudioVideoUnit_InputImagesEmbedder(),
+            LTX2AudioVideoUnit_InContextVideoEmbedder(),
         ]
         self.model_fn = model_fn_ltx2
```
```diff
@@ -105,18 +106,26 @@ def from_pretrained(

     def stage2_denoise(self, inputs_shared, inputs_posi, inputs_nega, progress_bar_cmd=tqdm):
         if inputs_shared["use_two_stage_pipeline"]:
+            if inputs_shared.get("clear_lora_before_state_two", False):
+                self.clear_lora()
             latent = self.video_vae_encoder.per_channel_statistics.un_normalize(inputs_shared["video_latents"])
             self.load_models_to_device('upsampler',)
             latent = self.upsampler(latent)
             latent = self.video_vae_encoder.per_channel_statistics.normalize(latent)
             self.scheduler.set_timesteps(special_case="stage2")
             inputs_shared.update({k.replace("stage2_", ""): v for k, v in inputs_shared.items() if k.startswith("stage2_")})
             denoise_mask_video = 1.0
+            # input image
             if inputs_shared.get("input_images", None) is not None:
                 latent, denoise_mask_video, initial_latents = self.apply_input_images_to_latents(
                     latent, inputs_shared.pop("input_latents"), inputs_shared["input_images_indexes"],
                     inputs_shared["input_images_strength"], latent.clone())
                 inputs_shared.update({"input_latents_video": initial_latents, "denoise_mask_video": denoise_mask_video})
+            # remove in-context video control in stage 2
+            inputs_shared.pop("in_context_video_latents", None)
+            inputs_shared.pop("in_context_video_positions", None)
+
+            # initialize latents for stage 2
             inputs_shared["video_latents"] = self.scheduler.sigmas[0] * denoise_mask_video * inputs_shared[
                 "video_noise"] + (1 - self.scheduler.sigmas[0] * denoise_mask_video) * latent
             inputs_shared["audio_latents"] = self.scheduler.sigmas[0] * inputs_shared["audio_noise"] + (
```
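The stage-2 initialization in this hunk is a linear interpolation between the upsampled stage-1 latent and fresh noise, gated by the denoise mask. A scalar sketch (toy values standing in for tensors; `sigma_0` is whatever the scheduler's first sigma happens to be):

```python
sigma_0 = 0.9        # assumed first scheduler sigma
denoise_mask = 1.0   # 1.0 -> this region is fully re-noised in stage 2
latent, noise = 0.5, 2.0

# Mirrors: sigmas[0] * mask * video_noise + (1 - sigmas[0] * mask) * latent
stage2_latent = sigma_0 * denoise_mask * noise + (1 - sigma_0 * denoise_mask) * latent
print(stage2_latent)  # ~1.85: mostly noise at mask = 1.0

# With mask = 0.0 (e.g. a conditioning frame) the latent passes through unchanged.
kept = sigma_0 * 0.0 * noise + (1 - sigma_0 * 0.0) * latent
print(kept)  # ~0.5
```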
```diff
@@ -145,18 +154,22 @@ def __call__(
         # Prompt
         prompt: str,
         negative_prompt: Optional[str] = "",
-        # Image-to-video
         denoising_strength: float = 1.0,
+        # Image-to-video
         input_images: Optional[list[Image.Image]] = None,
         input_images_indexes: Optional[list[int]] = None,
         input_images_strength: Optional[float] = 1.0,
+        # In-Context Video Control
+        in_context_videos: Optional[list[list[Image.Image]]] = None,
+        in_context_downsample_factor: Optional[int] = 2,
         # Randomness
         seed: Optional[int] = None,
         rand_device: Optional[str] = "cpu",
         # Shape
         height: Optional[int] = 512,
         width: Optional[int] = 768,
         num_frames=121,
+        frame_rate=24,
         # Classifier-free guidance
         cfg_scale: Optional[float] = 3.0,
         # Scheduler
@@ -169,6 +182,7 @@ def __call__(
         tile_overlap_in_frames: Optional[int] = 24,
         # Special Pipelines
         use_two_stage_pipeline: Optional[bool] = False,
+        clear_lora_before_state_two: Optional[bool] = False,
         use_distilled_pipeline: Optional[bool] = False,
         # progress_bar
         progress_bar_cmd=tqdm,
@@ -185,12 +199,13 @@ def __call__(
         }
         inputs_shared = {
             "input_images": input_images, "input_images_indexes": input_images_indexes, "input_images_strength": input_images_strength,
+            "in_context_videos": in_context_videos, "in_context_downsample_factor": in_context_downsample_factor,
             "seed": seed, "rand_device": rand_device,
-            "height": height, "width": width, "num_frames": num_frames,
+            "height": height, "width": width, "num_frames": num_frames, "frame_rate": frame_rate,
             "cfg_scale": cfg_scale,
             "tiled": tiled, "tile_size_in_pixels": tile_size_in_pixels, "tile_overlap_in_pixels": tile_overlap_in_pixels,
             "tile_size_in_frames": tile_size_in_frames, "tile_overlap_in_frames": tile_overlap_in_frames,
-            "use_two_stage_pipeline": use_two_stage_pipeline, "use_distilled_pipeline": use_distilled_pipeline,
+            "use_two_stage_pipeline": use_two_stage_pipeline, "use_distilled_pipeline": use_distilled_pipeline, "clear_lora_before_state_two": clear_lora_before_state_two,
             "video_patchifier": self.video_patchifier, "audio_patchifier": self.audio_patchifier,
         }
         for unit in self.units:
```
```diff
@@ -417,8 +432,8 @@ def process(self, pipe: LTX2AudioVideoPipeline, prompt: str):
 class LTX2AudioVideoUnit_NoiseInitializer(PipelineUnit):
     def __init__(self):
         super().__init__(
-            input_params=("height", "width", "num_frames", "seed", "rand_device", "use_two_stage_pipeline"),
-            output_params=("video_noise", "audio_noise",),
+            input_params=("height", "width", "num_frames", "seed", "rand_device", "frame_rate", "use_two_stage_pipeline"),
+            output_params=("video_noise", "audio_noise", "video_positions", "audio_positions", "video_latent_shape", "audio_latent_shape")
         )

     def process_stage(self, pipe: LTX2AudioVideoPipeline, height, width, num_frames, seed, rand_device, frame_rate=24.0):
@@ -471,7 +486,6 @@ def process(self, pipe: LTX2AudioVideoPipeline, input_video, video_noise, tiled,
         if pipe.scheduler.training:
             return {"video_latents": input_latents, "input_latents": input_latents}
         else:
-            # TODO: implement video-to-video
             raise NotImplementedError("Video-to-video not implemented yet.")

 class LTX2AudioVideoUnit_InputAudioEmbedder(PipelineUnit):
```
```diff
@@ -495,14 +509,13 @@ def process(self, pipe: LTX2AudioVideoPipeline, input_audio, audio_noise):
         if pipe.scheduler.training:
             return {"audio_latents": audio_input_latents, "audio_input_latents": audio_input_latents, "audio_positions": audio_positions, "audio_latent_shape": audio_latent_shape}
         else:
-            # TODO: implement video-to-video
-            raise NotImplementedError("Video-to-video not implemented yet.")
+            raise NotImplementedError("Audio-to-video not supported.")

 class LTX2AudioVideoUnit_InputImagesEmbedder(PipelineUnit):
     def __init__(self):
         super().__init__(
             input_params=("input_images", "input_images_indexes", "input_images_strength", "video_latents", "height", "width", "num_frames", "tiled", "tile_size_in_pixels", "tile_overlap_in_pixels", "use_two_stage_pipeline"),
-            output_params=("video_latents"),
+            output_params=("video_latents", "denoise_mask_video", "input_latents_video", "stage2_input_latents"),
             onload_model_names=("video_vae_encoder")
         )
```
```diff
@@ -537,6 +550,54 @@ def process(self, pipe: LTX2AudioVideoPipeline, input_images, input_images_index
         return output_dicts


+class LTX2AudioVideoUnit_InContextVideoEmbedder(PipelineUnit):
+    def __init__(self):
+        super().__init__(
+            input_params=("in_context_videos", "height", "width", "num_frames", "frame_rate", "in_context_downsample_factor", "tiled", "tile_size_in_pixels", "tile_overlap_in_pixels", "use_two_stage_pipeline"),
+            output_params=("in_context_video_latents", "in_context_video_positions"),
+            onload_model_names=("video_vae_encoder")
+        )
+
+    def check_in_context_video(self, pipe, in_context_video, height, width, num_frames, in_context_downsample_factor, use_two_stage_pipeline=True):
+        if in_context_video is None or len(in_context_video) == 0:
+            raise ValueError("In-context video is None or empty.")
+        in_context_video = in_context_video[:num_frames]
+        expected_height = height // in_context_downsample_factor // 2 if use_two_stage_pipeline else height // in_context_downsample_factor
+        expected_width = width // in_context_downsample_factor // 2 if use_two_stage_pipeline else width // in_context_downsample_factor
+        current_h, current_w, current_f = in_context_video[0].size[1], in_context_video[0].size[0], len(in_context_video)
+        h, w, f = pipe.check_resize_height_width(expected_height, expected_width, current_f, verbose=0)
+        if current_h != h or current_w != w:
+            in_context_video = [img.resize((w, h)) for img in in_context_video]
+        if current_f != f:
+            # pad black frames at the end
+            in_context_video = in_context_video + [Image.new("RGB", (w, h), (0, 0, 0))] * (f - current_f)
+        return in_context_video
+
+    def process(self, pipe: LTX2AudioVideoPipeline, in_context_videos, height, width, num_frames, frame_rate, in_context_downsample_factor, tiled, tile_size_in_pixels, tile_overlap_in_pixels, use_two_stage_pipeline=True):
+        if in_context_videos is None or len(in_context_videos) == 0:
+            return {}
+        else:
+            pipe.load_models_to_device(self.onload_model_names)
+            latents, positions = [], []
+            for in_context_video in in_context_videos:
+                in_context_video = self.check_in_context_video(pipe, in_context_video, height, width, num_frames, in_context_downsample_factor, use_two_stage_pipeline)
+                in_context_video = pipe.preprocess_video(in_context_video)
+                in_context_latents = pipe.video_vae_encoder.encode(in_context_video, tiled, tile_size_in_pixels, tile_overlap_in_pixels).to(dtype=pipe.torch_dtype, device=pipe.device)
+
+                latent_coords = pipe.video_patchifier.get_patch_grid_bounds(output_shape=VideoLatentShape.from_torch_shape(in_context_latents.shape), device=pipe.device)
+                video_positions = get_pixel_coords(latent_coords, VIDEO_SCALE_FACTORS, True).float()
+                video_positions[:, 0, ...] = video_positions[:, 0, ...] / frame_rate
+                video_positions[:, 1, ...] *= in_context_downsample_factor  # height axis
+                video_positions[:, 2, ...] *= in_context_downsample_factor  # width axis
+                video_positions = video_positions.to(pipe.torch_dtype)
+
+                latents.append(in_context_latents)
+                positions.append(video_positions)
+            latents = torch.cat(latents, dim=1)
+            positions = torch.cat(positions, dim=1)
+            return {"in_context_video_latents": latents, "in_context_video_positions": positions}
+
+
 def model_fn_ltx2(
     dit: LTXModel,
     video_latents=None,
```
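`check_in_context_video` sizes the reference clip relative to the generation resolution: it divides by `in_context_downsample_factor`, and by a further 2 when the two-stage pipeline is active (stage 1 denoises at half resolution). That arithmetic can be sketched as a standalone toy helper (names and values assumed for illustration):

```python
def expected_context_size(height: int, width: int,
                          downsample_factor: int,
                          use_two_stage_pipeline: bool) -> tuple[int, int]:
    # Two-stage: stage 1 runs at half resolution, so halve again.
    # Successive floor divisions (h // a // b) equal h // (a * b) for positive ints.
    divisor = downsample_factor * (2 if use_two_stage_pipeline else 1)
    return height // divisor, width // divisor

print(expected_context_size(1024, 1536, 2, True))   # (256, 384)
print(expected_context_size(1024, 1536, 2, False))  # (512, 768)
```

The result is then passed through `check_resize_height_width(..., verbose=0)`, so the context clip is also snapped to the pipeline's division factors before encoding.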
```diff
@@ -549,6 +610,8 @@ def model_fn_ltx2(
     audio_patchifier=None,
     timestep=None,
     denoise_mask_video=None,
+    in_context_video_latents=None,
+    in_context_video_positions=None,
     use_gradient_checkpointing=False,
     use_gradient_checkpointing_offload=False,
     **kwargs,
```
```diff
@@ -558,16 +621,25 @@ def model_fn_ltx2(
     # patchify
     b, c_v, f, h, w = video_latents.shape
     video_latents = video_patchifier.patchify(video_latents)
+    seq_len_video = video_latents.shape[1]
     video_timesteps = timestep.repeat(1, video_latents.shape[1], 1)
     if denoise_mask_video is not None:
         video_timesteps = video_patchifier.patchify(denoise_mask_video) * video_timesteps
+
+    if in_context_video_latents is not None:
+        in_context_video_latents = video_patchifier.patchify(in_context_video_latents)
+        in_context_video_timesteps = timestep.repeat(1, in_context_video_latents.shape[1], 1) * 0.
+        video_latents = torch.cat([video_latents, in_context_video_latents], dim=1)
+        video_positions = torch.cat([video_positions, in_context_video_positions], dim=2)
+        video_timesteps = torch.cat([video_timesteps, in_context_video_timesteps], dim=1)
+
     if audio_latents is not None:
         _, c_a, _, mel_bins = audio_latents.shape
         audio_latents = audio_patchifier.patchify(audio_latents)
         audio_timesteps = timestep.repeat(1, audio_latents.shape[1], 1)
     else:
         audio_timesteps = None
-    #TODO: support gradient checkpointing in training
+
     vx, ax = dit(
         video_latents=video_latents,
         video_positions=video_positions,
```
```diff
@@ -580,6 +652,8 @@ def model_fn_ltx2(
         use_gradient_checkpointing=use_gradient_checkpointing,
         use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
     )
+
+    vx = vx[:, :seq_len_video, ...]
     # unpatchify
     vx = video_patchifier.unpatchify_video(vx, f, h, w)
     ax = audio_patchifier.unpatchify_audio(ax, c_a, mel_bins) if ax is not None else None
```
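The `model_fn_ltx2` changes follow the usual in-context conditioning pattern: clean reference tokens are appended to the sequence with timestep 0 (so the model treats them as already denoised), attention mixes them with the noisy tokens, and the extra tokens are sliced off before unpatchifying. A minimal shape-level sketch with toy dimensions (an identity stands in for the real `dit`; all sizes are illustrative):

```python
import torch

b, n_video, n_ctx, dim = 1, 8, 4, 16
video_tokens = torch.randn(b, n_video, dim)
context_tokens = torch.randn(b, n_ctx, dim)  # from the in-context VAE latents
timestep = torch.full((b, 1, 1), 0.7)

video_timesteps = timestep.repeat(1, n_video, 1)
context_timesteps = timestep.repeat(1, n_ctx, 1) * 0.  # context is never noised

tokens = torch.cat([video_tokens, context_tokens], dim=1)
timesteps = torch.cat([video_timesteps, context_timesteps], dim=1)

out = tokens  # stand-in for dit(video_latents=tokens, ...)
out = out[:, :n_video, ...]  # mirrors vx = vx[:, :seq_len_video, ...]
print(out.shape)  # torch.Size([1, 8, 16])
```

Because only the first `seq_len_video` tokens are kept, the denoised output shape is unchanged by the conditioning, which is why stage 2 can simply drop the in-context latents.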
