@@ -61,6 +61,7 @@ def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
6161 LTX2AudioVideoUnit_InputAudioEmbedder (),
6262 LTX2AudioVideoUnit_InputVideoEmbedder (),
6363 LTX2AudioVideoUnit_InputImagesEmbedder (),
64+ LTX2AudioVideoUnit_InContextVideoEmbedder (),
6465 ]
6566 self .model_fn = model_fn_ltx2
6667
@@ -105,18 +106,26 @@ def from_pretrained(
105106
106107 def stage2_denoise (self , inputs_shared , inputs_posi , inputs_nega , progress_bar_cmd = tqdm ):
107108 if inputs_shared ["use_two_stage_pipeline" ]:
109+ if inputs_shared .get ("clear_lora_before_state_two" , False ):
110+ self .clear_lora ()
108111 latent = self .video_vae_encoder .per_channel_statistics .un_normalize (inputs_shared ["video_latents" ])
109112 self .load_models_to_device ('upsampler' ,)
110113 latent = self .upsampler (latent )
111114 latent = self .video_vae_encoder .per_channel_statistics .normalize (latent )
112115 self .scheduler .set_timesteps (special_case = "stage2" )
113116 inputs_shared .update ({k .replace ("stage2_" , "" ): v for k , v in inputs_shared .items () if k .startswith ("stage2_" )})
114117 denoise_mask_video = 1.0
118+ # input image
115119 if inputs_shared .get ("input_images" , None ) is not None :
116120 latent , denoise_mask_video , initial_latents = self .apply_input_images_to_latents (
117121 latent , inputs_shared .pop ("input_latents" ), inputs_shared ["input_images_indexes" ],
118122 inputs_shared ["input_images_strength" ], latent .clone ())
119123 inputs_shared .update ({"input_latents_video" : initial_latents , "denoise_mask_video" : denoise_mask_video })
124+ # remove in-context video control in stage 2
125+ inputs_shared .pop ("in_context_video_latents" , None )
126+ inputs_shared .pop ("in_context_video_positions" , None )
127+
128+ # initialize latents for stage 2
120129 inputs_shared ["video_latents" ] = self .scheduler .sigmas [0 ] * denoise_mask_video * inputs_shared [
121130 "video_noise" ] + (1 - self .scheduler .sigmas [0 ] * denoise_mask_video ) * latent
122131 inputs_shared ["audio_latents" ] = self .scheduler .sigmas [0 ] * inputs_shared ["audio_noise" ] + (
@@ -145,18 +154,22 @@ def __call__(
145154 # Prompt
146155 prompt : str ,
147156 negative_prompt : Optional [str ] = "" ,
148- # Image-to-video
149157 denoising_strength : float = 1.0 ,
158+ # Image-to-video
150159 input_images : Optional [list [Image .Image ]] = None ,
151160 input_images_indexes : Optional [list [int ]] = None ,
152161 input_images_strength : Optional [float ] = 1.0 ,
162+ # In-Context Video Control
163+ in_context_videos : Optional [list [list [Image .Image ]]] = None ,
164+ in_context_downsample_factor : Optional [int ] = 2 ,
153165 # Randomness
154166 seed : Optional [int ] = None ,
155167 rand_device : Optional [str ] = "cpu" ,
156168 # Shape
157169 height : Optional [int ] = 512 ,
158170 width : Optional [int ] = 768 ,
159171 num_frames = 121 ,
172+ frame_rate = 24 ,
160173 # Classifier-free guidance
161174 cfg_scale : Optional [float ] = 3.0 ,
162175 # Scheduler
@@ -169,6 +182,7 @@ def __call__(
169182 tile_overlap_in_frames : Optional [int ] = 24 ,
170183 # Special Pipelines
171184 use_two_stage_pipeline : Optional [bool ] = False ,
185+ clear_lora_before_state_two : Optional [bool ] = False ,
172186 use_distilled_pipeline : Optional [bool ] = False ,
173187 # progress_bar
174188 progress_bar_cmd = tqdm ,
@@ -185,12 +199,13 @@ def __call__(
185199 }
186200 inputs_shared = {
187201 "input_images" : input_images , "input_images_indexes" : input_images_indexes , "input_images_strength" : input_images_strength ,
202+ "in_context_videos" : in_context_videos , "in_context_downsample_factor" : in_context_downsample_factor ,
188203 "seed" : seed , "rand_device" : rand_device ,
189- "height" : height , "width" : width , "num_frames" : num_frames ,
204+ "height" : height , "width" : width , "num_frames" : num_frames , "frame_rate" : frame_rate ,
190205 "cfg_scale" : cfg_scale ,
191206 "tiled" : tiled , "tile_size_in_pixels" : tile_size_in_pixels , "tile_overlap_in_pixels" : tile_overlap_in_pixels ,
192207 "tile_size_in_frames" : tile_size_in_frames , "tile_overlap_in_frames" : tile_overlap_in_frames ,
193- "use_two_stage_pipeline" : use_two_stage_pipeline , "use_distilled_pipeline" : use_distilled_pipeline ,
208+ "use_two_stage_pipeline" : use_two_stage_pipeline , "use_distilled_pipeline" : use_distilled_pipeline , "clear_lora_before_state_two" : clear_lora_before_state_two ,
194209 "video_patchifier" : self .video_patchifier , "audio_patchifier" : self .audio_patchifier ,
195210 }
196211 for unit in self .units :
@@ -417,8 +432,8 @@ def process(self, pipe: LTX2AudioVideoPipeline, prompt: str):
417432class LTX2AudioVideoUnit_NoiseInitializer (PipelineUnit ):
418433 def __init__ (self ):
419434 super ().__init__ (
420- input_params = ("height" , "width" , "num_frames" , "seed" , "rand_device" , "use_two_stage_pipeline" ),
421- output_params = ("video_noise" , "audio_noise" ,),
435+ input_params = ("height" , "width" , "num_frames" , "seed" , "rand_device" , "frame_rate" , " use_two_stage_pipeline" ),
436+ output_params = ("video_noise" , "audio_noise" , "video_positions" , "audio_positions" , "video_latent_shape" , "audio_latent_shape" )
422437 )
423438
424439 def process_stage (self , pipe : LTX2AudioVideoPipeline , height , width , num_frames , seed , rand_device , frame_rate = 24.0 ):
@@ -471,7 +486,6 @@ def process(self, pipe: LTX2AudioVideoPipeline, input_video, video_noise, tiled,
471486 if pipe .scheduler .training :
472487 return {"video_latents" : input_latents , "input_latents" : input_latents }
473488 else :
474- # TODO: implement video-to-video
475489 raise NotImplementedError ("Video-to-video not implemented yet." )
476490
477491class LTX2AudioVideoUnit_InputAudioEmbedder (PipelineUnit ):
@@ -495,14 +509,13 @@ def process(self, pipe: LTX2AudioVideoPipeline, input_audio, audio_noise):
495509 if pipe .scheduler .training :
496510 return {"audio_latents" : audio_input_latents , "audio_input_latents" : audio_input_latents , "audio_positions" : audio_positions , "audio_latent_shape" : audio_latent_shape }
497511 else :
498- # TODO: implement video-to-video
499- raise NotImplementedError ("Video-to-video not implemented yet." )
512+ raise NotImplementedError ("Audio-to-video not supported." )
500513
501514class LTX2AudioVideoUnit_InputImagesEmbedder (PipelineUnit ):
502515 def __init__ (self ):
503516 super ().__init__ (
504517 input_params = ("input_images" , "input_images_indexes" , "input_images_strength" , "video_latents" , "height" , "width" , "num_frames" , "tiled" , "tile_size_in_pixels" , "tile_overlap_in_pixels" , "use_two_stage_pipeline" ),
505- output_params = ("video_latents" ),
518+ output_params = ("video_latents" , "denoise_mask_video" , "input_latents_video" , "stage2_input_latents" ),
506519 onload_model_names = ("video_vae_encoder" )
507520 )
508521
@@ -537,6 +550,54 @@ def process(self, pipe: LTX2AudioVideoPipeline, input_images, input_images_index
537550 return output_dicts
538551
539552
class LTX2AudioVideoUnit_InContextVideoEmbedder(PipelineUnit):
    """Encode reference ("in-context") videos into extra conditioning latents.

    Each reference clip is resized/padded to the resolution the current stage
    runs at (half resolution in stage 1 of the two-stage pipeline), encoded by
    the video VAE, and assigned pixel-space positions scaled back up by the
    downsample factor so its tokens align with the main video's position grid.
    """

    def __init__(self):
        super().__init__(
            input_params=("in_context_videos", "height", "width", "num_frames",
                          "frame_rate", "in_context_downsample_factor", "tiled",
                          "tile_size_in_pixels", "tile_overlap_in_pixels",
                          "use_two_stage_pipeline"),
            output_params=("in_context_video_latents",
                           "in_context_video_positions"),
            # Trailing comma makes this a real 1-tuple; a bare
            # ("video_vae_encoder") would just be a str.
            onload_model_names=("video_vae_encoder",),
        )

    def check_in_context_video(self, pipe, in_context_video, height, width,
                               num_frames, in_context_downsample_factor,
                               use_two_stage_pipeline=True):
        """Resize and pad/truncate one reference clip to the expected shape.

        Returns the list of frames at the (w, h, f) reported by
        ``pipe.check_resize_height_width``.

        Raises:
            ValueError: if the clip is None or empty.
        """
        if in_context_video is None or len(in_context_video) == 0:
            raise ValueError("In-context video is None or empty.")
        in_context_video = in_context_video[:num_frames]
        # Stage 1 of the two-stage pipeline runs at half resolution, hence the
        # extra // 2 before upsampling in stage 2.
        expected_height = (height // in_context_downsample_factor // 2
                           if use_two_stage_pipeline
                           else height // in_context_downsample_factor)
        expected_width = (width // in_context_downsample_factor // 2
                          if use_two_stage_pipeline
                          else width // in_context_downsample_factor)
        # PIL Image.size is (width, height).
        current_h = in_context_video[0].size[1]
        current_w = in_context_video[0].size[0]
        current_f = len(in_context_video)
        h, w, f = pipe.check_resize_height_width(
            expected_height, expected_width, current_f, verbose=0)
        if current_h != h or current_w != w:
            in_context_video = [img.resize((w, h)) for img in in_context_video]
        if current_f < f:
            # pad black frames at the end
            in_context_video = in_context_video + \
                [Image.new("RGB", (w, h), (0, 0, 0))] * (f - current_f)
        elif current_f > f:
            # The original multiplied the pad list by a negative count here,
            # which is a silent no-op and left the clip with too many frames;
            # truncate so the frame count matches the latent grid.
            in_context_video = in_context_video[:f]
        return in_context_video

    def process(self, pipe: LTX2AudioVideoPipeline, in_context_videos, height,
                width, num_frames, frame_rate, in_context_downsample_factor,
                tiled, tile_size_in_pixels, tile_overlap_in_pixels,
                use_two_stage_pipeline=True):
        if in_context_videos is None or len(in_context_videos) == 0:
            return {}
        pipe.load_models_to_device(self.onload_model_names)
        latents, positions = [], []
        for in_context_video in in_context_videos:
            in_context_video = self.check_in_context_video(
                pipe, in_context_video, height, width, num_frames,
                in_context_downsample_factor, use_two_stage_pipeline)
            in_context_video = pipe.preprocess_video(in_context_video)
            in_context_latents = pipe.video_vae_encoder.encode(
                in_context_video, tiled, tile_size_in_pixels,
                tile_overlap_in_pixels,
            ).to(dtype=pipe.torch_dtype, device=pipe.device)

            latent_coords = pipe.video_patchifier.get_patch_grid_bounds(
                output_shape=VideoLatentShape.from_torch_shape(
                    in_context_latents.shape),
                device=pipe.device)
            video_positions = get_pixel_coords(
                latent_coords, VIDEO_SCALE_FACTORS, True).float()
            # Time axis in seconds; spatial axes scaled back to full-resolution
            # pixel coordinates so they line up with the main video grid.
            video_positions[:, 0, ...] = video_positions[:, 0, ...] / frame_rate
            video_positions[:, 1, ...] *= in_context_downsample_factor  # height axis
            video_positions[:, 2, ...] *= in_context_downsample_factor  # width axis
            video_positions = video_positions.to(pipe.torch_dtype)

            latents.append(in_context_latents)
            positions.append(video_positions)
        # NOTE(review): with a single clip these cats are no-ops; with several,
        # dim=1 is the channel axis of the unpatchified latents and the coord
        # axis of positions (model_fn_ltx2 concatenates positions at dim 2) —
        # these dims look suspect for the multi-clip case, TODO confirm.
        latents = torch.cat(latents, dim=1)
        positions = torch.cat(positions, dim=1)
        return {"in_context_video_latents": latents,
                "in_context_video_positions": positions}
600+
540601def model_fn_ltx2 (
541602 dit : LTXModel ,
542603 video_latents = None ,
@@ -549,6 +610,8 @@ def model_fn_ltx2(
549610 audio_patchifier = None ,
550611 timestep = None ,
551612 denoise_mask_video = None ,
613+ in_context_video_latents = None ,
614+ in_context_video_positions = None ,
552615 use_gradient_checkpointing = False ,
553616 use_gradient_checkpointing_offload = False ,
554617 ** kwargs ,
@@ -558,16 +621,25 @@ def model_fn_ltx2(
558621 # patchify
559622 b , c_v , f , h , w = video_latents .shape
560623 video_latents = video_patchifier .patchify (video_latents )
624+ seq_len_video = video_latents .shape [1 ]
561625 video_timesteps = timestep .repeat (1 , video_latents .shape [1 ], 1 )
562626 if denoise_mask_video is not None :
563627 video_timesteps = video_patchifier .patchify (denoise_mask_video ) * video_timesteps
628+
629+ if in_context_video_latents is not None :
630+ in_context_video_latents = video_patchifier .patchify (in_context_video_latents )
631+ in_context_video_timesteps = timestep .repeat (1 , in_context_video_latents .shape [1 ], 1 ) * 0.
632+ video_latents = torch .cat ([video_latents , in_context_video_latents ], dim = 1 )
633+ video_positions = torch .cat ([video_positions , in_context_video_positions ], dim = 2 )
634+ video_timesteps = torch .cat ([video_timesteps , in_context_video_timesteps ], dim = 1 )
635+
564636 if audio_latents is not None :
565637 _ , c_a , _ , mel_bins = audio_latents .shape
566638 audio_latents = audio_patchifier .patchify (audio_latents )
567639 audio_timesteps = timestep .repeat (1 , audio_latents .shape [1 ], 1 )
568640 else :
569641 audio_timesteps = None
570- #TODO: support gradient checkpointing in training
642+
571643 vx , ax = dit (
572644 video_latents = video_latents ,
573645 video_positions = video_positions ,
@@ -580,6 +652,8 @@ def model_fn_ltx2(
580652 use_gradient_checkpointing = use_gradient_checkpointing ,
581653 use_gradient_checkpointing_offload = use_gradient_checkpointing_offload ,
582654 )
655+
656+ vx = vx [:, :seq_len_video , ...]
583657 # unpatchify
584658 vx = video_patchifier .unpatchify_video (vx , f , h , w )
585659 ax = audio_patchifier .unpatchify_audio (ax , c_a , mel_bins ) if ax is not None else None
0 commit comments