o
    Qe
n                     @   s   d dl Z ddlmZ ddlmZ ddlmZ ddlmZmZ ddl	m
Z
 dd	lmZ dd
lmZ d dlmZ d dlmZ d dlmZmZ d dlZd dlmZmZ g ZG dd deZdS )    N   )	Optimizer   )core)	framework)Variable
name_scope)LayerHelper)unique_name)layers)L2DecayRegularizer)_C_ops_legacy_C_ops)in_dygraph_mode_in_legacy_dygraphc                       s   e Zd ZdZdZ										d fdd		Zd
d Zdd Zdd Zdd Z	d fdd	Z
dd Zdd Zdd Zdd Z  ZS )Momentuma  

    Simple Momentum optimizer with velocity state

    This optimizer has a flag for Nestrov Momentum.

    The update equations are as follows:

    .. math::

        & velocity = mu * velocity + gradient

        & if (use\_nesterov):

        &\quad   param = param - (gradient + mu * velocity) * learning\_rate

        & else:

        &\quad   param = param - learning\_rate * velocity

    Parameters:

        learning_rate (float|Tensor|LearningRateDecay, optional): The learning rate used to update ``Parameter``.
            It can be a float value, a ``Tensor`` with a float type or a LearningRateDecay. The default value is 0.001.
        momentum (float): Momentum factor. The default value is 0.9.
        parameters (list|tuple, optional): List|Tuple of ``Tensor`` to update to minimize ``loss``. \
            This parameter is required in dygraph mode. And you can specify different options for \
            different parameter groups such as the learning rate, weight decay, etc, \
            then the parameters are list of dict. Note that the learning_rate in paramter groups \
            represents the scale of base learning_rate. \
            The default value is None in static mode, at this time all parameters will be updated.
        weight_decay (float|WeightDecayRegularizer, optional): The strategy of regularization. \
            It canbe a float value as coeff of L2 regularization or \
            :ref:`api_fluid_regularizer_L1Decay`, :ref:`api_fluid_regularizer_L2Decay`.
            If a parameter has set regularizer using :ref:`api_fluid_ParamAttr` already, \
            the regularization setting here in optimizer will be ignored for this parameter. \
            Otherwise, the regularization setting here in optimizer will take effect. \
            Default None, meaning there is no regularization.
        grad_clip (GradientClipBase, optional): Gradient cliping strategy, it's an instance of
            some derived class of ``GradientClipBase`` . There are three cliping strategies
            ( :ref:`api_fluid_clip_GradientClipByGlobalNorm` , :ref:`api_fluid_clip_GradientClipByNorm` ,
            :ref:`api_fluid_clip_GradientClipByValue` ). Default None, meaning there is no gradient clipping.
        multi_precision (bool, optional): Whether to use multi-precision during weight updating. Default is false.
        rescale_grad (float, optional): Multiply the gradient with `rescale_grad` before updating. \
            Often choose to be ``1.0/batch_size``.
        use_multi_tensor (bool, optional): Whether to use multi-tensor strategy to update all parameters at once . Default is false.
        name (str, optional): The default value is None. Normally there is no need for user
                to set this property. For more information, please refer to
                :ref:`api_guide_Name` .

    Examples:
        .. code-block:: python

            import paddle

            inp = paddle.uniform([10, 10], dtype="float32", min=-0.1, max=0.1)
            linear = paddle.nn.Linear(10, 10)
            inp = paddle.to_tensor(inp)
            out = linear(inp)
            loss = paddle.mean(out)
            beta1 = paddle.to_tensor([0.9], dtype="float32")
            beta2 = paddle.to_tensor([0.99], dtype="float32")
            momentum = paddle.optimizer.Momentum(learning_rate=0.1, parameters=linear.parameters(), weight_decay=0.01)
            back = out.backward()
            momentum.step()
            momentum.clear_grad()

            #Note that the learning_rate of linear_2 is 0.01.
            linear_1 = paddle.nn.Linear(10, 10)
            linear_2 = paddle.nn.Linear(10, 10)
            inp = paddle.uniform(shape=[10, 10], min=-0.1, max=0.1)
            out = linear_1(inp)
            out = linear_2(out)
            loss = paddle.mean(out)
            momentum = paddle.optimizer.Momentum(
                learning_rate=0.1,
                parameters=[{
                    'params': linear_1.parameters()
                }, {
                    'params': linear_2.parameters(),
                    'weight_decay': 0.001,
                    'learning_rate': 0.1
                }],
                weight_decay=0.01,
                momentum=0.9)                   
            out.backward()
            momentum.step()
            momentum.clear_grad()

    velocityMbP??NF      ?c                    s^  |d u rt d|d u rt ddd }t|trJt|d trJ|D ]'}d|v r,|d n|}| |\}}||d< ||d< ||rCd n|}||d< q"||rPd n|}tt| j|||||
d	 d
| _|| _	t
|| _| |\| _| _|| _|| _i | _|||| j| jd| _|	| _| jr|  | _|  | _|  | _d | jd< |  | _|  | _d S d S )Nzlearning_rate is not setzmomentum is not setc                 S   s   t | ttfS N)
isinstancer   float)Zregular r   ID:\Projects\ConvertPro\env\Lib\site-packages\paddle/optimizer/momentum.py<lambda>   s    z#Momentum.__init__.<locals>.<lambda>r   weight_decayregularization_methodregularization_coeff)learning_rate
parametersr   	grad_clipnamemomentum)r#   use_nesterovrescale_gradr   r   FP32_LODTensor)
ValueErrorr   listdict_update_regularizationsuperr   __init__type	_momentumbool_use_nesterov_regularization_method_regularization_coeff_multi_precision_rescale_grad_master_weights_default_dictZ_use_multi_tensorZ_create_multi_tensor_dict_param_dict_velocity_dict_master_weight_dict_regularization_method_dict_regularization_coeff_dict)selfr   r#   r    r$   r   r!   multi_precisionr%   Zuse_multi_tensorr"   	predicateZparam_groupZdecay
reg_method	reg_coeffZ
py_regular	__class__r   r   r,   ~   sd   








zMomentum.__init__c                 C   s6   d}d}t |trd}|j}t |trd}|}||fS )N         l2_decay)r   r   r2   r   )r<   r   r?   r@   r   r   r   r*      s   

zMomentum._update_regularizationc                 C   s   |j | jv r| j|j  }|S t| jtsJ |j d }t|}tj||j	dddd}| jj
 }|jdd|gid|gi|jtjjjd	d
 || j|j < |S )NZ_fp32_masterr   float32T)r"   shapevaluedtypeZpersistablecastXZOut)Zin_dtypeZ	out_dtype)r-   inputsoutputsattrs)r"   r5   r   helperr	   r
   generater   Zcreate_global_varrG   Zstartup_programZglobal_block	append_oprI   r   VarDescVarTypeZFP32)r<   paramvarvar_nameblockr   r   r   _create_master_weight   s0   

	zMomentum._create_master_weightc                 C   s~   | j dur| j d | }| jo|jtjjjk}|r| j|j n|}|j}|| j	vs0|| j	| vr8t
d||| j	| | S )a  Utility function to fetch an accumulator for a parameter

        Args:
            name: name of the accumulator
            param: parameter variable for which accumulator is to be fetched

        Returns:
            accumulator variable for the parameter
        N_z.Accumulator {} does not exist for parameter {})_namer3   rI   r   rR   rS   FP16r5   r"   Z_accumulators	Exceptionformat)r<   r"   rT   find_masterZtarget_paramtarget_namer   r   r   _get_accumulator   s   


zMomentum._get_accumulatorc                 C   s   t |tjsJ t |tr| |}|D ]1}| jr.|jtjj	j
kr.| |}| | j| q|jtjj	j
kr>| js>td | | j| qdS )zE
        if framework._non_static_mode():
            return
        zAccumulating with FP16 in optimizer can lead to poor accuracy or slow convergence.Consider using multi_precision=True option of the Momentum optimizer.N)r   r   Blockr)   _update_param_groupr3   rI   r   rR   rS   r[   rX   Z_add_accumulator_velocity_acc_strwarningswarn)r<   rW   r    pZmaster_pr   r   r   _create_accumulators
  s    


zMomentum._create_accumulatorsc                    s.   t |drt|jtr|S tt| |||S )zpCreate and add backward regularization Operators

        Function helper of append_regularization_ops.
        regularizer)hasattrr   rh   r   r+   r   _create_regularization_of_grad)r<   rT   ZgradZregularizationrA   r   r   rj   #  s   
z'Momentum._create_regularization_of_gradc                 C   s  t |tjsJ t |tr| |}| | j|d }| |}|d }| j}| j	}t
|drEt |jtr<d}|jj	}n	|jd urEd}d}| joQ|d jtjjjk}|r\| j|d j nd }	t rt |trm| |d  t|d |d |||	|d ||	d| jd	| jd
|d|d|\}
}
}
d S t rt |tr| |d  t|d |d |||	| j| j|||| jS | j| j|||| jd}|d g|d g|g|gd}|d g|gd}|r|	|d< |	|d< |j| j |||dd}|S )Nr   rh   rE   rC   rD   r   r   mur$   r   r   r=   )rk   r$   r   r   r=   r%   ParamZGradVelocityZLearningRateZParamOutZVelocityOutMasterParamMasterParamOutTr-   rL   rM   rN   stop_gradient)!r   r   ra   r)   rb   r`   rc   _create_param_lrr1   r2   ri   rh   r   r3   rI   r   rR   rS   r[   r5   r"   r   r*   r   r#   r.   r0   r   r   Z	momentum_r4   rQ   r-   )r<   rW   param_and_gradvelocity_acclrrT   r   r   r^   master_weightrY   rN   rL   rM   Zmomentum_opr   r   r   _append_optimize_op2  s   










zMomentum._append_optimize_opc                 C   s^  |  || |D ]}| | j|}| j}| j}t|dr2t|jtr)d}|jj}n	|jdur2d}d}|j	t
jkra| jd | | | jd | | | jd | | | jd | | q|j	t
jkr| jd | | | jd | | | jr| jd | | j|j  nd| jd |< | jd | | | jd | | qtddS )	a  
        All parameters used for optimizer (such as: parameters, master_weight, velocity_acc for momentum) calculations are grouped into a python list by data type (float16, float32).
        This function will be overridden in the corresponding optimizer file.

        Args:
            target_block: the block in which the loss tensor is present
            parameters: list of parameter tensors for the optimizer
        rh   rE   NrC   rD   r&   FP16_LODTensorzWNow multi_tensor_momentum only support fp32 and fp16 parameters and grad is LOD_TENSOR.)rg   r`   rc   r1   r2   ri   r   rh   r   rI   paddlerF   r7   appendr8   r:   r;   float16r3   r9   r5   r"   r'   )r<   target_blockr    param_group_idxrT   rv   r   r   r   r   r   _multi_tensor_init  sp   	



zMomentum._multi_tensor_initc                 C   s  t |tjsJ g g d}g g d}t |trz|D ]_}|d du r"q|d jdu rx|d jtjkrQ|d jt	j
jjkrQ|d |d  | |}|d | q|d jtjkrx|d jt	j
jjkrx|d |d  | |}|d | qn||d D ]w}|d du rq~|d jdu rt }||d< |d	d
 | D  | |}|d jtjkr|d jt	j
jjkr|d |d  | |}|d | q~|d jtjkr|d jt	j
jjkr|d |d  | |}|d | q~ddg}	|	D ]}
t| j|
 | dkr| jo|
dk}| j|
 }|dur || nd}t rt rYt| j|
 | ||
 | j|
 | ||
 || j| j| j|
 | | j |
 | || j!\}}}qt"#| j|
 | ||
 | j|
 | ||
 || j|
 | | j|
 | |d| jd| jd| j|
 | d| j |
 | d|\}}}q| j|
 | ||
 | j|
 | ||
 d}| j|
 | | j|
 | d}| j| j| j|
 | | j |
 | d}|r| j|
 | |d< | j|
 | |d< ||d< |j$d|||dd qdS )zM
        For Multi Tensor, append optimize merged_operator to block.
        )r&   rz   r   Nr   Fr&   rz   paramsc                 S   s   i | ]\}}|d kr||qS )r   r   ).0kvr   r   r   
<dictcomp>  s
    z=Momentum._append_optimize_multi_tensor_op.<locals>.<dictcomp>rk   r$   r   r   r=   rl   ro   )rk   r$   r   r   rp   rq   merged_momentumTrr   )%r   r   ra   r(   rs   rI   r{   rF   r-   r   rR   rS   Z
LOD_TENSORr|   rt   r}   r)   updateitemsrb   lenr7   r3   r9   Z_non_static_moder   r   Zmerged_momentum_r8   r.   r0   r:   r;   r4   r   r   rQ   )r<   r~   Zparameters_and_gradsr   Z	grad_dictZlr_dictru   rw   Zparam_grad_dictZmulti_tensor_listkeyr^   rx   rY   rL   rM   rN   r   r   r    _append_optimize_multi_tensor_op  s"  	










z)Momentum._append_optimize_multi_tensor_opc                 C   sr   | d| jd | _| d| jd | _| d| jd | _| d| jd | _| d| jd | _| d}|S )Nr#   r$   r%   r   r   r   )getr6   r.   r0   r4   r1   r2   )r<   r    r   r   r   rb     s"   





zMomentum._update_param_group)
r   r   NFNNFr   FNr   )__name__
__module____qualname____doc__rc   r,   r*   rX   r`   rg   rj   ry   r   r   rb   __classcell__r   r   rA   r   r   !   s0    ZFo@ &r   )rd   Z	optimizerr   Zfluidr   r   Zfluid.frameworkr   r   Zfluid.layer_helperr	   r
   r   Zpaddle.fluidZpaddle.fluid.regularizerr   r{   r   r   Zpaddle.fluid.frameworkr   r   __all__r   r   r   r   r   <module>   s   