o
    Met                    @   s  d dl Z d dlmZ d dlZd dlm  mZ d dlmZ d dl	m
Z
 d dlmZmZ d dlmZ d dlm  m  mZ ddlmZ d	d
lmZ d	dlmZmZ d	dlmZmZmZ d	dlmZm Z  d	dlm!Z!m"Z" d	dlm#Z#m$Z$m%Z% d	dl&m'Z' d	dlm(Z(m)Z) ddga*g da+ddgZ,dd Z-G dd dZ.G dd dZ/G dd dZ0G dd  d Z1G d!d" d"Z2G d#d$ d$Z3G d%d& d&Z4G d'd( d(Z5dS ))    N)reduce)unique_name)LayerHelper)ProgramOpProtoHolder)OpRole   )_get_global_env   )DistributedContext)OperatorDistributedAttributeTensorDistributedAttribute)new_process_groupProcessGroup_g_process_group_map)build_comm_descCommContext)AllgatherOpCost
SendOpCost)SliceOpCostSplitOpCostConcatOpCost)Cluster)print_program_with_dist_attris_gradient_clip_opZcheck_finite_and_unscaleZupdate_loss_scaling)sumsqrtfill_constantZelementwise_maxZelementwise_divwhileconditional_blockc                 C   s@   d}| |j v r|j |  }n|| }|dusJ d|j|S )z=Get var in the parent block if not found in the current blockNz{} is not found)varsZ_var_recursiveformatname)var_nameblockprogramvar r'   XD:\Projects\ConvertPro\env\Lib\site-packages\paddle/distributed/auto_parallel/reshard.pyget_var_with_recursion+   s   

r)   c                   @   sR   e Zd ZdZdddZedd Zedd Zed	d
 Zedd Z	dd Z
dS )AllGatherOpDescz
    Describe the allgather op in the reshard phase.

    Args:
        group (list): Process group.
        shape (list): The tensor shape.
        is_bool (bool): Whether allgather bool data. Default: False.
    Fc                 C   s   || _ d| _|| _|| _d S )NZ
all_gather)_group_desc_shape_is_bool)selfgroupshapeis_boolr'   r'   r(   __init__D   s   
zAllGatherOpDesc.__init__c                 C      | j S Nr.   r/   r'   r'   r(   r2   J      zAllGatherOpDesc.is_boolc                 C   r4   r5   )r+   r7   r'   r'   r(   r0   N   r8   zAllGatherOpDesc.groupc                 C   r4   r5   r,   r7   r'   r'   r(   descR   r8   zAllGatherOpDesc.descc                 C   r4   r5   r-   r7   r'   r'   r(   r1   V   r8   zAllGatherOpDesc.shapec              	   C   s&   d| j  d| j d| j d| j d	S )Nop: z	, group: 	, shape: , is_bool: .)r,   r+   r-   r.   r7   r'   r'   r(   __repr__Z   s   &zAllGatherOpDesc.__repr__NF)__name__
__module____qualname____doc__r3   propertyr2   r0   r:   r1   r@   r'   r'   r'   r(   r*   :   s    
	



r*   c                   @   j   e Zd ZdZdddZedd Zedd Zed	d
 Zedd Z	edd Z
edd Zdd ZdS )
SendOpDesca0  
    Describe the send op in the reshard phase.

    Args:
        partition_index (list): The index of partition in complete tensor.
        src (int): The source process to send.
        dst (int): The destination process to receive.
        is_bool (bool): Whether send bool data. Default: False.
    Fc                 C   s(   || _ || _d| _g | _|| _|| _d S )Nsend)_dst_partition_indexr,   r-   r.   _srcr/   partition_indexsrcdstr2   r'   r'   r(   r3   i      
zSendOpDesc.__init__c                 C   r4   r5   rL   r7   r'   r'   r(   rO   q   r8   zSendOpDesc.srcc                 C   r4   r5   r6   r7   r'   r'   r(   r2   u   r8   zSendOpDesc.is_boolc                 C   r4   r5   rK   r7   r'   r'   r(   rN   y   r8   zSendOpDesc.partition_indexc                 C   r4   r5   rJ   r7   r'   r'   r(   rP   }   r8   zSendOpDesc.dstc                 C   r4   r5   r9   r7   r'   r'   r(   r:      r8   zSendOpDesc.descc                 C   0   | j s| jD ]}| j |d |d   q| j S Nr
   r   r-   rN   appendr/   itemr'   r'   r(   r1         
zSendOpDesc.shapec                 C   .   d| j  d| j d| j d| j d| j dS Nr<   z, partition_index: z, dst: r=   r>   r?   r,   rK   rJ   r-   r.   r7   r'   r'   r(   r@         .zSendOpDesc.__repr__NrA   )rB   rC   rD   rE   r3   rF   rO   r2   rN   rP   r:   r1   r@   r'   r'   r'   r(   rH   ^        







rH   c                   @   rG   )
RecvOpDesca0  
    Describe the recv op in the reshard op.

    Args:
        partition_index (list): The index of partition in complete tensor.
        src (int): The source process to send.
        dst (int): The destination process to receive.
        is_bool (bool): Whether receive bool data. Default: False.
    Fc                 C   s(   || _ || _d| _g | _|| _|| _d S )Nrecv)rL   rK   r,   r-   r.   rJ   rM   r'   r'   r(   r3      rQ   zRecvOpDesc.__init__c                 C   r4   r5   rT   r7   r'   r'   r(   rP      r8   zRecvOpDesc.dstc                 C   r4   r5   r6   r7   r'   r'   r(   r2      r8   zRecvOpDesc.is_boolc                 C   r4   r5   rS   r7   r'   r'   r(   rN      r8   zRecvOpDesc.partition_indexc                 C   r4   r5   rR   r7   r'   r'   r(   rO      r8   zRecvOpDesc.srcc                 C   r4   r5   r9   r7   r'   r'   r(   r:      r8   zRecvOpDesc.descc                 C   rU   rV   rW   rY   r'   r'   r(   r1      r[   zRecvOpDesc.shapec                 C   r\   r]   r^   r7   r'   r'   r(   r@      r_   zRecvOpDesc.__repr__NrA   )rB   rC   rD   rE   r3   rF   rP   r2   rN   rO   r:   r1   r@   r'   r'   r'   r(   ra      r`   ra   c                   @   s^   e Zd ZdZdddZedd Zedd Zed	d
 Zedd Z	edd Z
dd ZdS )SliceOpDescac  
    Describe the slice op in the reshard phase.

    Args:
        starts (list): It represents start indices of corresponding axis in ``axes``.
        ends (list):  It represents end indices of corresponding axis in ``axes``.
        axes (list):  Axes that `starts` and `ends` apply to.
        shape (list): The shape of the tensor to be sliced.
    Nc                 C   s"   || _ || _|| _d| _|| _d S )Nslice)_starts_ends_axesr,   r-   )r/   startsendsaxesr1   r'   r'   r(   r3      s
   
zSliceOpDesc.__init__c                 C   r4   r5   )re   r7   r'   r'   r(   rh      r8   zSliceOpDesc.startsc                 C   r4   r5   )rf   r7   r'   r'   r(   ri      r8   zSliceOpDesc.endsc                 C   r4   r5   )rg   r7   r'   r'   r(   rj      r8   zSliceOpDesc.axesc                 C   r4   r5   r9   r7   r'   r'   r(   r:      r8   zSliceOpDesc.descc                 C   r4   r5   r;   r7   r'   r'   r(   r1      r8   zSliceOpDesc.shapec                 C   s^   | j d urd| j d| j d| j d| j d| j  dS d| j d| j d| j d| j d	S )Nr<   z
, starts: z, ends: z, axes: r=   r?   )r-   r,   re   rf   rg   r7   r'   r'   r(   r@      s   
.&zSliceOpDesc.__repr__r5   )rB   rC   rD   rE   r3   rF   rh   ri   rj   r:   r1   r@   r'   r'   r'   r(   rc      s    






rc   c                   @   s8   e Zd ZdZdd Zedd Zedd Zdd	 Zd
S )ConcatOpDescz
    Describe the concat op in the reshard phase.

    Args:
        partition_index_list (list): The list contains all partition index.
    c                 C   s   || _ d| _d S )Nconcat)_partition_index_listr,   )r/   partition_index_listr'   r'   r(   r3      s   
zConcatOpDesc.__init__c                 C   r4   r5   )rm   r7   r'   r'   r(   rn      r8   z!ConcatOpDesc.partition_index_listc                 C   r4   r5   r9   r7   r'   r'   r(   r:      r8   zConcatOpDesc.descc                 C   s   d| j  d| j dS )Nr<   z, partition_index_list: r?   )r,   rm   r7   r'   r'   r(   r@     s   zConcatOpDesc.__repr__N)	rB   rC   rD   rE   r3   rF   rn   r:   r@   r'   r'   r'   r(   rk      s    

rk   c                   @   s   e Zd ZdZedd Zedd Zedd Zedd	 Zed
d Z	edd Z
edddZedd Zedd Zedd ZdS )Inserterz*Insert op required in the reshard process.c              	   C   sj   t jjdddg}| j|||j|jd}| j|dd|gid|gi|j	|j	|dd	}|
d
d |S )Nr?   zcast@RESHARDtmpr"   dtypetype	lod_levelcastXOutZin_dtypeZ	out_dtypeop_rolers   inputsoutputsattrsop_namescope/auto_parallel/reshard)paddlefluidr   generate_with_ignorable_keyjoin
create_varrs   rt   
_insert_oprr   	_set_attr)r$   idxtensorry   Ztensor_typenew_var_nameoutZcast_opr'   r'   r(   insert_cast_op
  s&   	zInserter.insert_cast_opc           	   
   C   sN   d}t ||g}| j||d|gi|j|j|d|ddd}|dd dS )	z-Insert send op into block at the given index.send_v2rv   T)ring_idpeeruse_calc_streamry   dynamic_shape)rs   r{   r}   r~   r   N)r   r   idranksindexr   )	r$   r   r   rO   rP   ry   op_typeprocess_groupZsend_opr'   r'   r(   insert_send_op  s   

zInserter.insert_send_opc           	      C   s^   d}t ||g}| j||d|gid|gi|j|j||j|jd|ddd}|dd d	S )
z-Insert recv op into block at the given index.Zrecv_v2rv   rw   T)r   r   Z	out_shaperr   r   ry   r   rz   r~   r   N)r   r   r   r   r   r1   rr   r   )	r$   r   r   rO   rP   ry   r   r   Zrecv_opr'   r'   r(   insert_recv_op1  s    
zInserter.insert_recv_opc                 C   sf   t jjdddg}| j||j|j|j|j	d}| j
|d||dd|id|id	}|d
d |S )z2Insert reset_lod op into block at the given index.r?   zreset_lod@RESHARDrp   r"   r1   rs   rr   rt   	lod_resetrv   Yrw   ry   rz   r~   r   )r   r   r   r   r   r   r1   rs   rr   rt   r   r   )r$   r   rv   r   ry   r   reset_lod_outZreset_opr'   r'   r(   insert_reset_lod_opF  s&   zInserter.insert_reset_lod_opc           
   
   C   s   d|i}i }||d< ||d< t di t }tj| j( | jtjj	d
|jdg|d jd|d j|d jd	d	d
}W d   n1 sHw   Y  | j|d|d|gi|d}	|	dd |S )z/Insert concat op into block at the given block.rv   axisry   concat@RESHARDr?   rp   r   NFr"   rr   r1   rt   rs   persistablestop_gradientrl   rw   rz   r~   r   )r   )r   localsr   staticprogram_guardr%   r   r   r   r   r   r"   rr   rt   rs   r   r   )
r$   r   Ztensorsr   ry   r{   r}   helperr   Z	concat_opr'   r'   r(   insert_concat_op]  s4   
zInserter.insert_concat_opc                    s(  j }fddttD }	g }
t|	D ]\}}||| kr&|
| qt|
dkrZ j|jj|	jd}dgi}d|gi}ddi} j	|d	|||d
}|
dd |S t|
dkr|
d }|| |	|  }|}| |	|  }|}di}|||d}g }tj D ]\}}||kr|| q|||  qtj j  fddt|D }|| }W d   n1 sw   Y   j	|d|d|i|d
}|
dd |S di}tdd tt|D }|||d} j|jjjd} j	|d|d|gi|d
}|
dd |S )z.Insert slice op into block at the given block.c                    s   g | ]
} | |  qS r'   r'   .0i)ri   rh   r'   r(   
<listcomp>  s    z,Inserter.insert_slice_op.<locals>.<listcomp>r   )r"   rr   rs   r1   rt   rv   rw   Zin_placeFassignrz   r~   r   r
   numr   ry   c                    s>   g | ]} j tjjd ddgjdjdjddqS )r?   split@RESHARDrp   NF)r"   rr   r1   rs   r   rt   r   )	r   r   r   r   r   r   rr   rs   rt   r   )r$   r   r'   r(   r     s    

NsplitInputc                 s       | ]}d V  qdS r
   Nr'   r   r'   r'   r(   	<genexpr>      z+Inserter.insert_slice_op.<locals>.<genexpr>)rj   rh   ri   infer_flagsry   rq   rd   )r1   rangelen	enumeraterX   r   rr   rs   rt   r   r   r   r   r   r%   list)r$   r   r   rh   ri   rj   r   ry   Zglobal_shapeslice_shapeZ	diff_dimsr   rZ   r   r{   r|   r}   Zslice_opZdiff_dimnum_or_sectionsr   Zcur_idxinput_shape	new_shapeoutssplit_opr   r'   )r$   ri   rh   r   r(   insert_slice_opx  s   




zInserter.insert_slice_opr   c                    s   t di t j}di}|||d}g }	tjD ]\}
}|
|kr*|	| q|	||  qtj j  fddt	|D }W d   n1 sQw   Y   j
|d|d|i|d	}|d
d |S )z.Insert split op into block at the given index.r   rv   r   c                    s@   g | ]} j tjjd jdgjdjj	dddqS )r?   rp   NFr   )
r   r   r   r   r   r   r"   rr   rt   rs   r   r$   r   r   r'   r(   r     s    	z,Inserter.insert_split_op.<locals>.<listcomp>Nr   rw   rz   r~   r   )r   )r   r   r1   r   rX   r   r   r   r%   r   r   r   )r$   r   r   r   ry   r   r   r{   r}   r   r   rZ   r   r   r'   r   r(   insert_split_op  s,   	zInserter.insert_split_opc              	   C   s   t di t }tj| j" | jtjj	d
|jdgtjdtjjjddd}W d   n1 s4w   Y  i }ddi}ttd|d	< td|d
< |j|d< ||d< tj||dgdd | j|d|d|gi|d}d|_|dd |S )z6Insert fill constant op into block at the given index.fill_constant@RESHARDr?   rp   NF)r"   rr   r1   rs   r   r   Z	force_cpu1Z	str_valuevaluerr   ry   r   r   )r{   r}   r1   r   rw   rz   Tr~   r   )r   )r   r   r   r   r   r%   r   r   r   r   r   r"   int64coreZVarDescZVarTypeZ
LOD_TENSORstrintrr   utilsZget_shape_tensor_inputsr   r   r   )r$   r   ry   r   r   r{   r}   Zfillconstant_opr'   r'   r(   insert_fill_constant_op  sB   	
z Inserter.insert_fill_constant_opc              
   C   s  g }t |}d}| sNt| ||}d|_| j|d dd|gid|gidd|dd}	|	d	d
 | j|d dd|gid|gid|id}
|
d	d
 d}d}t|d fi t }t	j
| j" | jt	jjd|jdg|jd|j|jddd}W d   n1 sw   Y  | j|| |d|gid|gi|jd|j|dd}|d	d
 |d7 }t| || ||j|}|d7 }|| ||fS )z2Insert allgather op into block at the given index.r   Tr
   Zc_allreduce_sumrv   rw   )r   r   ry   rz   r~   r   r   Zc_sync_calc_streamry      c_allgather@RESHARDr?   rp   NFr   )r   r   nranksry   )r   Zis_instantiatero   r   r   r   r   r   r   r   r   r   r%   r   r   r   r   r   r"   rr   rt   rs   r   r   r   extend)r$   r   r   r   ry   tensor_listr0   
idx_offsetZfill_constant_outZallreduce_opZsync_calc_opr   r   Zallgather_outZallgather_opZ	split_outr'   r'   r(   insert_allgather_op  s|   




zInserter.insert_allgather_opc                 C   s   | s|  ||f dS d}d}|t| k rpt| | d |\}}	}
|dkrfd}|	dkr=t||d | | d |g||nt||d || | d g||}| | |d  d7  < t| ||
||| n
|d7 }|t| k s|s{|  ||f dS dS )z(Concat the tensors and insert concat op.r   Fr
   TN)rX   r   	Reshardercompute_concat_inforo   r   popconcat_partitions_with_op)partition_tensor_listr   rN   r$   r   ry   r   
has_concatconcat_axisfirst_ordernew_partition_r'   r'   r(   r   _  s4   
" 
z"Inserter.concat_partitions_with_opNr   )rB   rC   rD   rE   staticmethodr   r   r   r   r   r   r   r   r   r   r'   r'   r'   r(   ro     s,    





\ 
!
Fro   c                   @   s@   e Zd ZdZedd Zedd Zedd Zedd	 Zd
S )Removerz)Remove var and op in the reshard process.c              	   C   s  g d}g }t jD ]}|| q	t| jD ]\}}||vr#|| q|D ]}g }| j| }|j}|j}	t|D ]\}
}|jdkrog }|jD ]}|	t
||| j qGt|
ddD ]}|| jdkrm|| d|  nqZq9|jdkrg }|jD ]}|t
||| j}||jv r|| qy|s||
 q9t |j}|j|jd j| |j|jd j| q9||}|dur|j}||jvr|j|vr||
 q9|ddd D ]}
||
 qq&dS )	z&Remove no need ops in the main program)create_py_readercreate_double_buffer_readerreadr   r   r   Zshape_concatc_sync_comm_streamr   N)r   while_block_inforX   r   blocksopsr    rs   output_arg_namesr   r)   r1   r   r   input_arg_namesZ get_tensor_dist_attr_for_programprocess_mesh	processesr   instanceget_op_protor:   	set_inputr{   r"   
set_outputr|   get_op_dist_attr_for_program
_remove_op)auto_parallel_main_progdist_contextrank_idZnot_remove_op_refZremove_block_order	block_idxr$   remove_op_idxr   r    r   opZdim_listr#   r   Z	need_saver   protoop_dist_attrop_process_meshr'   r'   r(   remove_no_need_ops~  sx   











zRemover.remove_no_need_opsc                 C   s  t | jD ]\}}t }|j}|j}t }|D ] }	|	jD ]}
|
|v r'||
 q|	jD ]}
|
|v r6||
 q+q|D ]}||vrE|| q:|dkri }|D ]*}	t|		dtt
jkrxd|	jv rxd|	jv rx|	dd }|	dd }|||< qNg }t |D ]\}}|d j| vr|| q|ddd D ]}|| qd}|t|k r|| d j}|| d j}||| kr|| |||  f||< |d7 }|t|k s|D ]}||v rq|| qqdS )z'Remove no need vars in the main programr   ry   ParamZGradNr   r
   )r   r   setr   r    r   addr   r   attrr   ZOptimizeZinput_namesinputr"   keysrX   r   r   _remove_var)r   dist_params_gradsfeed_var_namesr   r$   remove_varsr   r    	need_varsr   r#   r&   Zparam_grad_map
param_nameZ	grad_nameZneed_remove_idxr   rZ   r'   r'   r(   remove_no_need_vars  sd   





zRemover.remove_no_need_varsc                 C   sV   t | || t| | g }tt|j g D ]}||j	 qt 
| || dS )z0Remove no need vars and ops in the main program.N)r   r   r    change_while_op_input_and_outputr   r   Zserial_feed_varsvaluesrX   r"   r
  )r   r   r   r  r  r&   r'   r'   r(   remove_no_need_in_main  s   zRemover.remove_no_need_in_mainc                 C   s0  t  }|  j}|D ]}|jD ]}|| qq
| }t  }|j}|D ]}|jdkr,q$|jD ]}|| q/q$t  }	|D ]}||v rH|	| q=|j}t  }
t|D ]2\}}d}|jdkr_qS|jD ]
}||	v rld} nqb|r|jD ]}|
| qr|jD ]}|
| q}qSt  }|jD ]}||
vr|| q|D ]}|	| qg }|j}t|jD ]Y\}}d}|jdkrg }|jD ]}||v r|
| q|s|
| nt |j}|j|jd j| |j|jd j| q|jD ]
}||vrd} nq|r|
| q|ddd D ]}|| qdS )z3Remove no need vars and ops in the startup program.r   FTr   Nr   )r   Zglobal_blockr   r   r   rs   r   r   r    r  rX   r   r   r   r:   r   r{   r"   r   r|   r   )r   auto_parallel_startup_progZmain_input_varsZmain_opsr   r#   Zstartup_blockZstartup_output_varsZstartup_opsr  Zactual_need_varsr   Z
is_need_opr  r&   r   r    Zis_no_need_opZ	var_namesr   r'   r'   r(   remove_no_need_in_startup  s   















z!Remover.remove_no_need_in_startupN)	rB   rC   rD   rE   r   r   r
  r  r  r'   r'   r'   r(   r   {  s    
?
3
r   c                   @   s  e Zd ZdZi Z	dQddZedd Zedd Zed	d
 Z	edd Z
edd Zedd Zedd Zedd Zedd Zedd Zedd Zedd Zedd Zedd  Zed!d" Zed#d$ Zd%d& Zd'd( Zd)d* Zd+d, ZdRd.d/Zd0d1 ZdSd3d4Zd5d6 Zd7d8 Z d9d: Z!d;d< Z"d=d> Z#d?d@ Z$dAdB Z%dCdD Z&dEdF Z'dGdH Z(dIdJ Z)dKdL Z*dMdN Z+dOdP Z,dS )Tr   a!  
    Reshard tensor in the program according to its distributed attribute and corresponding op distributed attribute.

    Args:
        auto_parallel_main_prog (Program): An auto parallel main program.
        auto_parallel_startup_prog (Program): An auto parallel startup program.
        rank_id (int): The process id.
        dist_context (DistributedContext): The distributed context of this rank.
        dist_params_grads (list): The list contains the tuple of param and grad.
        batch_size (int): The batch size. Default: None.
    Nc                 C   s   t |tsJ dt||d ur t |ts J dt|t |ts.J dt|t |ts<J dt||d urNt |tsNJ dt||| _|| _|| _|| _	|| _
|| _i | _i | _i | _i | _d S )NzBThe type of auto_parallel_main_prog should be Program, but got {}.zMThe type of auto_parallel_startup_prog should be Program or None, but got {}.z.The type of rank_id should be int, but got {}.zBThe type of dist_context should be DistributedContext, but got {}.z1The type of batch_size should be int, but got {}.)
isinstancer   r!   rs   r   r   _auto_parallel_main_prog_auto_parallel_startup_prog_rank_id_dist_context_dist_params_grads_batch_size	_has_sent	_has_recv_has_allgather_has_resharded)r/   r   r  r   r   r  
batch_sizer'   r'   r(   r3   Z  s6   





zResharder.__init__c                 C   r4   r5   )r  r7   r'   r'   r(   r   {  r8   z!Resharder.auto_parallel_main_progc                 C   r4   r5   )r  r7   r'   r'   r(   r    r8   z$Resharder.auto_parallel_startup_progc                 C   r4   r5   )r  r7   r'   r'   r(   r     r8   zResharder.rank_idc                 C   r4   r5   )r  r7   r'   r'   r(   r     r8   zResharder.dist_contextc                 C   r4   r5   )r  r7   r'   r'   r(   r    r8   zResharder.dist_params_gradsc                 C   r4   r5   )r  r7   r'   r'   r(   r    r8   zResharder.batch_sizec                 C   r4   r5   )r  r7   r'   r'   r(   has_sent  r8   zResharder.has_sentc                 C   r4   r5   )r  r7   r'   r'   r(   has_recv  r8   zResharder.has_recvc                 C   r4   r5   )r  r7   r'   r'   r(   has_allgather  r8   zResharder.has_allgatherc                 C   sH   g }t | D ]\}}|| dkr|| q|||||    q|S )zCompute the shape of partition.r   r   rX   )complete_shapedims_mappingprocess_shapepartition_shaper   rZ   r'   r'   r(   compute_partition_shape  s   z!Resharder.compute_partition_shapec                 C   sh   | | }g }tdd |}tt|D ]}||||   }|||  }||| |  }|| q|S )z@Compute the index of process_shape corresponding to the process.c                 S   s   | | S r5   r'   )xyr'   r'   r(   <lambda>      z1Resharder.compute_process_index.<locals>.<lambda>)r   r   r   r   rX   )processr   r"  Zrelative_processprocess_indexproductr   r   r'   r'   r(   compute_process_index  s   
zResharder.compute_process_indexc           	      C   s   t |||}t | ||}g }tt|D ]+}|| dkr(|d|| g q||||  ||  |||  d ||  g q|S )z/Compute the partition index in complete tensor.r   r   r
   )r   r$  r,  r   r   rX   )	r)  r   r!  r"  r   r#  r*  rN   r   r'   r'   r(   compute_partition_index  s   z!Resharder.compute_partition_indexc                 C   s   d}d}d}g }t | D ]]\}}||| krd|d7 }|d || d kr>|d || d k r>|}||d || d g q|d || d krc|d || d krcd}|}||| d |d g q|| q|dkrs|||fS d||fS )zYJudge whether two partition can be concatenated and compute concatenated partition index.r   r   r
   r  )Zpartition_index_xZpartition_index_yZdiffer_countr   r   r   r   rZ   r'   r'   r(   r     s2   

zResharder.compute_concat_infoc                 C   sH   g }t | D ]\}}|| dkr|| q|||||    q|S )zVcompute the complete shape of the slice tensor  with its process mesh and dims mappingr   r  )r   r"  r!  r   r   rZ   r'   r'   r(   compute_complete_shape  s   z Resharder.compute_complete_shapec                 C   s   | s	|  | dS d}d}|t| k r:t| | |\}}}|dkr0d}| | t| | n
|d7 }|t| k s|sC|  | dS dS )z8Concat the given partitions without inserting concat op.r   Fr   Tr
   N)rX   r   r   r   r   concat_partitions)rn   rN   r   r   r   r   r   r'   r'   r(   r/    s(   


zResharder.concat_partitionsc                 C   s  t jD ]}| j| }t j| d }| j|j }t }g }|jD ]8}||}	|	s=|jdkr/|	r=|jdkr6|	r=|jdkrW|	sW|jD ]}
|
|vrK|	|
 q@|j
D ]}
||
 qOqd}|jD ]}|j |kro|jdkro|} nq]|du ruqt |j}g }|dD ]}
|
|v r|	|
 q|sJ |  |j|jd j| g }|d	D ](}
|ddd
 D ]}||
d
krt|
t|ksd|v r||vr|	| qq|sJ |j|jd j| qdS )zNChange while op input and output after the corresponding sub block ops removedop_idrd   r   r   Nr   rv   r   rw   r   r   )r   r   r   Z
parent_idxr   r   get_dist_op_for_programrs   r   rX   r   r   r:   r   r   r   r   r  sortr   r{   r"   outputfindr   r   r|   )r   r   Zsub_block_idx	sub_blockZparent_while_op_idZparent_blockZsub_block_op_inputsZsub_block_op_outputsr   dist_opr#   Zwhile_opr   Znew_XZnew_OutZoutput_namer'   r'   r(   r    sl   











z*Resharder.change_while_op_input_and_outputc                 C   sT   d}|d |d   kr|d k s&n |d |d   kr#|d k r(n |S d}|S )zBJudge whether two partitions intersect on the specified dimension.Fr   r
   Tr'   )r/   Zshape_xZshape_y
overlappedr'   r'   r(   is_overlappedB  s   6zResharder.is_overlappedc                 C   s   |D ]	}|dkr dS qdS )Nr   FTr'   )r/   r!  dimr'   r'   r(   
is_unshardJ  s
   zResharder.is_unshardc                 C   s(   |j tv rdS t|r|j tv rdS dS )NTF)rs   _g_special_opsr   _g_gradient_clip_ops)r/   r   r'   r'   r(   is_special_opP  s
   
zResharder.is_special_opc           
      C   s   | j j|dj }|jdkr|d}n
|jdkr|d}|D ]"}t||| j }| j|}|j	}|j
}|D ]
}	|	dkrB  dS q8q!dS )	Nr5  r   	Conditionr   ZCondr   FT)r   r   r  r   rs   r  r)   r   get_dist_tensor_for_program	dist_attrr!  )
r/   r   r5  Z
input_condr#   r&   dist_tensortensor_dist_attrZvar_dims_mappingr9  r'   r'   r(   is_condition_replicativeX  s$   


z"Resharder.is_condition_replicativeTc                 C   s   d}|j }|j}|j}|d }	|rK|d }
ttdd |||
|	grI||
krC|| jjvrA|D ]
}|dkr8tdq.|s=|S tdd	}||	krId	}|S |d }ttd
d ||||	grj||krdtd||	krjd	}|S )z/Judge the tensor whether needs to be resharded.Fr   r
   c                 S      | S r5   r'   r%  r'   r'   r(   r'  z      z(Resharder.need_reshard.<locals>.<lambda>r   z7The dim must be -1 when tensor process mesh is a union.zJit is not supported that tensor process mesh is a union and needs reshard.Tc                 S   rD  r5   r'   rE  r'   r'   r(   r'    rF  zVIt is not supported that tensor dims mapping is different from op output dims mapping.)r@  r!  r   allmapr   process_meshes
ValueError)r/   rA  r@  Zop_inputr6  Z
is_reshardrB  Ztensor_dims_mappingtensor_process_meshr   op_input_dims_mappingrZ   Zop_output_dims_mappingr'   r'   r(   need_reshardm  sX   zResharder.need_reshardc                 C   sj   g }| j |}|jj}| j jD ]}t|jt|j@ r+t|jt|jk r+|| q|s3|| |S )zEGet sub process meshes of the given op if op process mesh is a union.)	r   r1  r@  r   rI  r   r   r   rX   )r/   r   rI  r6  r   r   r'   r'   r(   get_op_process_meshes  s&   


zResharder.get_op_process_meshesFc           0   	   C   s  |j }|j}|j}|j}|j}|j}	|j}
|d }|d }|j}|j}|jd dk rC|jd dks3J t|j}| j	|d< |j
| |sMt|j|
|n|j}i }t|t|	rgt|t|	rg	 |S ||	krg }|	D ]X}t||||
|	}|s|||gdgg qptdd |D }tdd |D }tdd |D }||dkr||}|| | || d qp|||gdgg qp|D ]=}g }t|||||}g }g }|	D ]}t||||
|	}d	}td
d tt| j||D r||vrtdd |D |}tdd |D | }tdd |D | }d} | t|k rF||  s;||  }d|| < n| d7 } | t|k s-| t|kr^ttdd |}|d }d|d< |d	usgJ d|| vrrg ||< || vr}g ||< || |jjtjk}!t||||!d}"t||||!d}#|| |" || |# || t|| q|| t | g }$g }%g }&|d }'g }(t!|'D ]/\}})|$|| d |)d   |%|| d |)d   |&| |(|)d |)d   q|| t"|$|%|&|(d q|S g }g }g }*|	D ],}t||||
|	}||vr3|| |*|g|g q|*|| d | qt#t|*d d D ]} g }+t#t|*D ]},|+|*|, d |   | dkrp||*|, d  qU|+D ]c}-g }$g }%g }&t|-||||}t!|D ]\}})|$|)d  |%|)d  |&| q|$ }(t"|$|%|&|(d}.|sd	n|j%|-d}/t|+dkrt&|+|/|jtjkdt |d|.gn|.g||-< qtqK|S )a  
        Find the op description sequence to reshard the source tensor for matching the op requirement.

        Args:
            dist_tensor (DistributedTensor): A distributed tensor.
            dist_attr (list): A list contains process_mesh and dims_mapping such as [process_mesh, dims_mapping].
            serial (bool): If serial is true, the dist tensor and dist op come from serial program. Otherwise, they come from auto program.

        Returns:
            Dict, the dict represents the required op description sequence corresponding to process, The key of dict is
            process and value is a list containing op description.
        r   r
   r   Fc                 S      g | ]}|d  qS r   r'   r   rZ   r'   r'   r(   r         z.Resharder.find_op_desc_seq.<locals>.<listcomp>c                 S   rO  r
   r'   rP  r'   r'   r(   r     rQ  c                 S   rO  r   r'   rP  r'   r'   r(   r     rQ  Nc                 s   s    | ]}|V  qd S r5   r'   )r   r   r'   r'   r(   r     r   z-Resharder.find_op_desc_seq.<locals>.<genexpr>c                 S   rO  r   r'   rP  r'   r'   r(   r         c                 S   rO  rS  r'   rP  r'   r'   r(   r     rT  c                 S   rO  rR  r'   rP  r'   r'   r(   r     rT  Tc                 S   s   dS NFr'   rE  r'   r'   r(   r'    rF  z,Resharder.find_op_desc_seq.<locals>.<lambda>z Failed to find the send process.)r2   )r1   )rh   ri   rj   r1   )rank)r0   r1   r2   )rn   )'r@  serial_tensorr"   r!  r   r   Ztopologyr1   r   r  r:   	set_shaper   r.  r   intersection
differencer-  rX   countr   rG  rH  r8  r   r  rr   r   boolrH   ra   r/  rk   r   rc   r   Zglobal_sizesZlocal_sizesr*   )0r/   rA  r@  serialrB  source_tensortensor_nameZsource_dims_mappingZsource_process_meshZsource_process_groupZsource_process_shapeZtarget_process_meshZtarget_dims_mappingZtarget_process_groupZtarget_process_shaper   r   op_desc_seqZpartition_process_mapping_listZsource_processZsource_partition_indexZpartition_listZprocess_listZhas_usedr   Ztarget_processr  Ztarget_partition_indexrn   Zall_partition_index_listZto_send_processr   r   r2   Zsend_op_descZrecv_op_descZslice_startsZ
slice_endsZslices_axesZconcatenated_partition_indexto_slice_tensor_shaperZ   r*  r0   jr)  Zslice_op_descZallgather_shaper'   r'   r(   find_op_desc_seq  s  


 # 



"




7





zResharder.find_op_desc_seqc           -         sX  g }g }j | vrdS |j  }d}	tt jD ]\}
}|jj|jjkr+|
}	 nq|	dus8J dj  j|	 }t| j	}|D ]}t
|tr
|j vr[g j|< j| ro|jttdd j| vr|jrt |	||dtj}t |	d ||j|d\}}|	|7 }	g }|D ]}t |	||dtj}||j |	d7 }	qj| |j|g nDt |	||j|d\}}|	|7 }	dd |D }j| |j|g nj| D ]}|j|d	 kr fd
d|d D } nq|s	J dqFt
|trj|j vrg j|< |jj| vri|jrMt |	||dtj}t |	d ||j|j|d |	d7 }	nt |	||j|j|d |	d7 }	j| |j qFt
|trz|j vr}i j|< |jj|  vrn|j }g }|D ]}
||
d |
d	   q|jr j!t"#|d ||j$tj|j%d}t& |	||j|j|d t |	d ||dtj}|| |	d7 }	|j| |j< qF j!t"#|d ||j$|j'|j%d}t& |	||j|j|d |j$d	kr\d}j	j(D ]B}|j)D ]6}|j)| }|j*rK|j$|j$krKt+ |	d |||d}|| |	d7 }	|j| |j< d} nq|rR nq|du s[J qF|| |	d7 }	|j| |j< qF|j| |j  qFt
|t,r|j-}|	g}t|D ]\}
}t.||||
  ||d q|d	 }	qFt
|t/r)t0|dks|rJ t0|dkr|d	 d	 n|}t"#|d } tj1 |	||j2|j3|j4| |dd}!|d	 }"|d }#t5 }$|#|$_6|"|$_7j89|!|$ |j%dkrFdt:j;|dj  vri t:j;|dj d< |t:j;|dj d  vr3g t:j;|dj d |< t:j;|dj d | ||!jg  jD ]}g }%|j<D ]}&j8=|}'|&|kr
|'dur
|j |j kr|j%dkr|&}(|!j} |(| ksJ |'>|(})|'?| |) |'@| |# |(|'jAv r|'B|( |%|  qP|jC|&|!j |&}(|!j} |(| ksJ |'>|(})|'?| |) |'@| |# |'B|( qP|'j7}*|'D|}+|*|"kr
|+|#kr
|jC|&|!j |&}(|!j} |(| ksJ |'>|(})|'?| |) |'@| |# |'B|( qP|%r'tEF G|j%},|jH|,jId	 j|Jd|%  qIqFdS )z1Parse op desc sequence and insert op in the blockNz:The op for reshard cannot be found in the rank {} program.c                 S   s   | d S )Nr   r'   rE  r'   r'   r(   r'    r(  z)Resharder.parse_op_desc.<locals>.<lambda>ry   r
   c                 S   s   g | ]}|j qS r'   )r"   )r   r&   r'   r'   r(   r     s    z+Resharder.parse_op_desc.<locals>.<listcomp>r   c                    s   g | ]	}t | jqS r'   )r)   r   )r   r#   r$   r/   r'   r(   r     s    z6The result of parsing allgather op should not be None.r   @recvr"   r1   rt   rr   rs   FTr   )rh   ri   rj   r   ry   r   var_reshard_mappingr5  rv   )Kr   r  r   r   r   r:   r   r!   r)   r   r  r*   r  r0   rH  r2   ro   r   r  r   r   r   r\  rX   r"   rH   r  rP   r   rO   ra   r  rN   r   r   generatert   rs   r   rr   r   r    is_datar   rk   rn   r   rc   r   r   rh   ri   rj   r   r!  r   r   Z set_tensor_dist_attr_for_programr   r   r   r   get_input_dist_attrset_input_dist_attrZset_input_dims_mappingZ_inputs_dist_attrsdel_input_dist_attr_rename_inputget_input_dims_mappingr   r   r   r   r{   r  )-r/   r$   r`  r#   Z
reshard_opr@  r   r   op_desc_listr   r   r   Z
matched_opr^  op_descZout_castr   Ztensor_name_listr&   rZ   rN   r1   Zrecv_tensorset_lod	tmp_blocktmp_var_nametmp_varr   rn   Zidx_listr   Zto_slice_tensornew_nameZtarget_tensorr   r!  Ztensor_attrZwhile_op_X_appendr"   r   old_nameop_input_dist_attrr   rL  r   r'   rd  r(   parse_op_desc  s\  































 zResharder.parse_op_descc                 C   s   |j tv sJ | jj|dj }|j}g }|D ]@}| j|}|s#q|j	}|j
D ].}||krW|j}	||}
d}|D ]}|	|d krM|
|d krMd} nq;|sW||	|
g q)q|S )Nr5  Fr   r
   T)rs   _g_subblock_opsr   r   r  r   r   r   r1  r@  r   r   rn  rX   )r/   r   r#   r5  r   input_attrsr6  r@  r"   r   input_dims_mappingZ	has_exist
input_attrr'   r'   r(   _get_subblock_input_attrs  s:   
z#Resharder._get_subblock_input_attrsc           
      C   s   g }| j |}|j}|j}| j jD ]}t|jt|j@ r-t|jt|jk r-|| q|s5|| |	|}g }	|D ]	}|	||g q>|	S r5   )
r   r1  r@  r   rI  r   r   r   rX   rn  )
r/   r   r#   rI  r6  r@  r   r   r{  rz  r'   r'   r(   _get_common_op_input_attrs  s0   



z$Resharder._get_common_op_input_attrsc                 C   s4   g }|j tv r| ||}n| ||}|sJ |S r5   )rs   ry  r}  r~  )r/   r   r#   op_input_attrsr'   r'   r(   get_op_input_attrs  s   
zResharder.get_op_input_attrsc                 C   s   t  }t| jj}|dkrnd}| jjD ]}|jD ]}|| qqt| jjD ]\}}tt |jt|kr:|} nq'|durpd}| jj| }t| jjD ]\}	}
|	|krVqMt |
jt |jk rbd}qM|rr| jj| dS dS dS dS )z;Remove global process mesh from dist_context.process_meshesr
   NFT)r   r   r   rI  r   r   r   r   )r/   r   Zprocess_mesh_countZglobal_process_mesh_idxr   r)  r   Z
is_removedZglobal_meshr   Zmeshr'   r'   r(   _remove_global_process_mesh  s:   
	z%Resharder._remove_global_process_meshc                 C   s  dt j| v rt j| d }|jD ]}|jD ]a}||v rw| j|}|j}d }|| D ]}	|j|	d d krG|||	d d krG|	d } nq+|d u rMq|j	
|| | j|}|j}
|}|}||ksgJ |
|}|
|| |
| q|jD ]D}||v rt|| dkrtd|| d d }|j	|| | j|}|j}
|}|}||ksJ |
|}|
|| |
| q{qd S d S )Nrg  r   r
   zpThe scene is not supported that the output is inplaced and the tensor has been resharded multiply when as input.)r   r   r   r   r   r1  r@  r   rn  r:   rm  rj  rk  rl  r   r   rJ  Z_rename_outputZget_output_dist_attrZset_output_dist_attrZdel_output_dist_attr)r/   r   r$   rg  r   r#   r6  r@  target_namerZ   r   rv  ru  rw  Zop_output_dist_attrr'   r'   r(   $_change_subblock_op_input_and_output  s   






z.Resharder._change_subblock_op_input_and_outputc              	   C   s  d}|t |jk rt |j}|j| }| |r|d7 }q| j|}|d urg }|jtv rX| |s8td|	dj
tjvrJi tj|	dj
< |j
 tj|	dj
 d< |jdkrc|d}n|jdkrn|d	}n|j}|  d}|D ]w}	d
|	v rqyt|	|| j}
| j|
}d}|jj| jjvr| jjrd}|jjdt |jjksJ | ||	}|D ];}d }|rt|d jt|jjjkrq|d ur| ||r| ||}| |||	|| t |j}|| | }|}qqy|| d }n|d7 }|t |jk s
d S d S )Nr   r
   zFPlease check the condition due to the dims mapping is not replicative.r5  r0  r   rv   r   r   Zlod_tensor_blocking_queueFTr   )r   r   r=  r   r1  rs   ry  rC  rJ  r  r   r   r   r:   r  r   r2  r)   r   r?  r@  r   rI  r!  r[  r  r   r   rM  rc  rx  )r/   r$   r   pre_op_countr   r6  Zop_input_dist_attrsZinput_var_namesr   r#   r&   rA  Zis_union_process_mesh_tensorr  r|  Zinput_process_meshreshard_op_desccur_op_countr'   r'   r(   _reshard_input  s   









zResharder._reshard_inputc                 C   sl  | j |kr4|jtjkr|jt|jd |j|j	tj
|jd}t||d ||||d d }|j	dkrd}	| jjD ]J}
|
jD ]@}|
j| }|jr|j	|j	kr|jt|jd |j|j|j|j	d}|d7 }|j|d	||d
d|id|did d}	 nqA|	r nq<|	du sJ |j|d dd|d u r|gn|gid|gi|j|j|ddd d S |j	dkr#|jt|jd |j|j	|j
|jd}t||d ||||d d}	| jjD ]9}
|
jD ].}|
j| }|jr|j	|j	kr|d7 }|j|d	||d
d|id|did d}	 nq|	r nq|	du s!J d S t||d ||||d d S d S )Nre  rf  r
   ry   r   Fz	@RESETLODr   r   r   rw   rz   Tr   ru   rv   rx   )r   rr   r   r\  r   r   rh  r"   r1   rt   r   rs   ro   r   r  r   r   r    ri  r   )r/   r$   r   r&   r   	send_rank	recv_rankZrecv_cast_outr   rq  rr  rs  rt  Zrecv_outr'   r'   r(   _hadnle_recvd  s   


	




	
zResharder._hadnle_recvc              
   C   sn   |j tjkr&t||d ||dtj}t||d ||||d d S t||d ||||d d S )Nr
   ry   r   )rr   r   r\  ro   r   r  r   r   )r/   r$   r   r&   r   r  r  Zcast_outr'   r'   r(   _handle_send  s   
zResharder._handle_sendc              
   C   s  d}g d}|t 7 }|t7 }|t|jk rHt|j}|j| }| j|}|d ur;|j|vr;d}|jD ]}t||| j	}	| j
|	}
|
jj}|jj|j|g}|
d ur3| |
|dr3t|jt|jt|d j@  }|r3t|t|d jkr|
jjdt|
jjks|d dt|d krtdt|D ]K\}}|}|}|t|d jkr|t|d j t|d j }|d j| }||krq| j|kr| |||	||| | j|kr| |||	||| qn8t|D ]3\}}|}|d j| }||krq| j|kr| |||	||| | j|kr%| |||	||| qt|j}|| | }|}q6|| d }n|d7 }|t|jk sd S d S )Nr   )r   r   r   Zwrite_to_arrayZread_from_arrayFr   r
   zThe dims_mapping must be -1)r;  ry  r   r   r   r1  rs   r   r)   r   r?  r@  r   Zget_output_dims_mappingrM  r   r   r!  r[  rJ  r   r   r  r  )r/   r$   r   Zskip_opsr  r   r6  r   r#   r&   rA  rK  Zoutput_attrZtensor_processesr   Ztensor_processr  Zactual_indexrZ   r  r'   r'   r(   _reshard_output  s   











zResharder._reshard_outputc                 C   sz   |    t| jjD ]\}}|tjv r| || | | | | q
t	
| j| j| j| j t	| j| j i t_d S r5   )r  r   r   r   r   r   r  r  r  r   r  r   r   r  r  r  )r/   r   r$   r'   r'   r(   reshard&  s   


zResharder.reshardc                 C   s  t dg }d }|j|v r|S |j}|dkr|S | j|}| j|}|j|j}	|jj}
|
|	g}|d ur| 	||r|| j
vrH|g| j
|< n'| j
| D ]}|j}||}|j}|	|krf||
krf|  S qM| j
| | | j||dd}|jj}| |||}|S )Nr   Zlod_tensor_blocking_queue_0T)r]  )r;  rs   r"   r   r?  r1  r@  rn  r   rM  r  rX   rc  rW  rr   parse_op_desc_for_cost)r/   r   r   clusterZnot_supported_op_typeZreshard_op_costr_  rA  r6  r!  r   r@  rZ   Zitem_dist_attrZitem_dims_mappingZitem_process_meshr  rr   r'   r'   r(   get_cost@  sP   


zResharder.get_costc                 C   s  |s	| | d S d}d}|t|k r~t|| |\}	}
}|	dkrtd}i }d|d< d|	i|d< |
dkrCd	||| f||fgi|d
< nd	||f||| fgi|d
< || ||vr^g ||< ||  t||d | |||||| n
|d7 }|t|k s|s| | d S d S )Nr   Fr   Trl   r   r   r}   rv   r{   rp  r  r
   )rX   r   r   r   r   r   _concat_partitions_for_cost)r/   r   rN   rr   r   local_rank_comp_costr  r   r   r   r   r   Zconcat_descr'   r'   r(   r  k  sP   




z%Resharder._concat_partitions_for_costc                 C   s  dd }t |}g }g }i }|D ]g}	g }
||	 }|D ][}t|tra|	|jg}|j}td|||}|||\}}|d u rQ||t||dfg |t| q|s`|| |t||df qt|t	r|j
}|j}td|||}g }t|D ]\}}|dkr||t|  qy|| qy|||\}}|d u r||t||dfg |t| n|s|| |t||df |	|vrg ||	< i }d|d< d	||fgi|d	< t|dd
|d< ||	 t||d qt|tr|j}t|D ]\}}| |
|||	|| qqt|trw|	|vrg ||	< t|
dks#|
r#J g }t|
dkr@|
d D ]}||d |d   q0n|j}i }d|d< tdd tt|jD }|j|j|j|d|d< d||fgi|d	< ||	 t||d qq||f}|S )Nc                 S   s   d\}}d}|t | k r>| | t|krd}|D ]}|| | v r)|}| | | q|d u r3|d7 }n	 ||fS |t | k s||fS )NrU  r   Tr
   )r   r   r   )
comm_ranksgroup_ranksresis_the_samer   rV  r'   r'   r(   _get_idx  s    
z2Resharder.parse_op_desc_for_cost.<locals>._get_idxr   )rp  comm_contextr   r   r   r   r{   )r   r   r}   r  r
   rd   c                 s   r   r   r'   r   r'   r'   r(   r     r   z3Resharder.parse_op_desc_for_cost.<locals>.<genexpr>)rj   rh   ri   r   r   )r   r  rH   rP   r1   r   rX   r   r   r*   r0   r   r   r   r   rk   rm   r  rc   r   r   rj   rh   ri   r   )r/   r  rr   r  r  r  Z
comm_costsr  r  keyr   ro  rp  r  r1   Z	send_descr   r  Zallgather_descZsplit_inputs_shaper9  Z
split_descrn   Zpartion_idexra  rZ   Z
slice_descr   r  r'   r'   r(   r    s   










Zz Resharder.parse_op_desc_for_costr5   )TNrA   )-rB   rC   rD   rE   r   r3   rF   r   r  r   r   r  r  r  r  r  r   r$  r,  r-  r   r.  r/  r  r8  r:  r=  rC  rM  rN  rc  rx  r}  r~  r  r  r  r  r  r  r  r  r  r  r  r'   r'   r'   r(   r   L  sz    
!
















6
8
 O  7MZ
^+)r   )6copy	functoolsr   r   Zpaddle.fluid.corer   r   Zpaddle.utilsr   Zpaddle.fluid.layer_helperr   Zpaddle.fluid.frameworkr   r   Z/paddle.distributed.fleet.meta_optimizers.commonr   Zpaddle.fluid.layers.utilsZlayersr   Z
collectiver	   r   r   Zdist_attributer   r   r   r   r   r   Zcostr   r   r   r   r   r   r   r  r   r   r   r;  r<  ry  r)   r*   rH   ra   rc   rk   ro   r   r   r'   r'   r'   r(   <module>   sB   $22-  v R