티스토리 뷰

728x90
반응형

안녕하세요. Teus입니다.

 

이번 포스팅은 SIMD(Single Instrunction Multi Data)를 다룹니다.

 

이번에는 AVX2의 기본 연산 중 쓸만한 함수 중 FMA연산을 정리합니다.

1. FMA

FMA는 Fused Multiply and Add 연산 입니다.

 

영단어에서 알 수 있듯

 

곱셈과 동시에 덧셈을 하는 연산입니다.

 

가장 간단하게 표현하면

ret = a*b + c

 

위 연산을 SIMD로 연산하는 명령어 입니다.

 

그럼 위 연산은 위해서 3개의 Vector가 필요하게 됩니다.

TMI
한번에 add와 mul연산을 동시에 할 수 있기 때문에, 한번에 2flops연산을 할 수가 있습니다.(한 cycle에 연산된다는 가정) 그래서 GPU에서는 core의 개수 * clock * 2flops로 성능이 측정됩니다. 그래서 현대 AI워크로드에 활용되는 GPU인 A100같은 경우 6912 * 1410 Mhz * 2Ops/hz = 19491840 => 19.5TFLOPs가 나오게 됩니다.(32bit Float 기준, 64bit Float기준으로는 절반이 됩니다)

Data Type Description
_mm/mm256_fmadd_ps/pd/ Multiply two vectors and add the product to a third (res = a * b + c)
_mm/mm256_fmsub_ps/pd/ Multiply two vectors and subtract a vector from the product (res = a * b - c)
_mm_fmadd_ss/sd Multiply and add the lowest element in the vectors (res[0] = a[0] * b[0] + c[0])
_mm_fmsub_ss/sd Multiply and subtract the lowest element in the vectors (res[0] = a[0] * b[0] - c[0])
_mm/mm256_fnmadd_ps/pd Multiply two vectors and add the negated product to a third (res = -(a * b) + c)
_mm/mm256_fnmsub_ps/pd Multiply two vectors and add the negated product to a third (res = -(a * b) - c)
_mm_fnmadd_ss/sd Multiply the two lowest elements and add the negated product to the lowest element of the third vector (res[0] = -(a[0] * b[0]) + c[0])
_mm_fnmsub_ss/sd Multiply the lowest elements and subtract the lowest element of the third vector from the negated product (res[0] = -(a[0] * b[0]) - c[0])
_mm/mm256_fmaddsub_ps/pd Multiply two vectors and alternately add and subtract from the product (res = a * b - c)
_mm/mm256_fmsubadd_ps/pd Multiply two vectors and alternately subtract and add from the product (res = a * b - c)

FMA연산같은 경우 보면 알겠지만, DNN에 많이 사용되게 됩니다.

DNN이 output = input * weight + bias로 계산되기 때문이죵

image

이때 특이한 연산으로, 뒤에 ss/sd가 붙는 연산이 있습니다.

 

해당 연산같은 경우

 

3개 Vector의 첫번째 값에 대해서만 FMA연산을 하고

 

나머지 값은 a vector로 채우는 연산 입니다.
_mm_fmadd_ss

image.png

그래서 아래와 같이 코드를 실행하면, 첫 Vector의 값만 FMA를 시키고, 치환하는 효과를 보여줍니다.
(역시 비전공자 입장에서, 해당 연산이 어디에 사용되는지는 잘...)

#include <stdio.h>
#include <stdlib.h>
#include <malloc.h>
#include <xmmintrin.h>
#include <immintrin.h >

int main() {            
    __m128 arr1 = _mm_setr_ps(0.5, 0.9, 0.1, 0.4);
    __m128 arr2 = _mm_setr_ps(0.1, 0.3, 0.7, 0.4);
    __m128 arr3 = _mm_setr_ps(0.3, 0.2, 0.9, 0.5);
    __m128 ret = _mm_fmadd_ss(arr1, arr2, arr3);
    float* my_ret = (float*)&ret;

    float* _arr1 = (float*)&arr1;
    float* _arr2 = (float*)&arr2;
    for (int i = 0; i < 4; i++) {
        printf("idx : %d, arr1[%d] = %f, arr2[%d] = %f\n", i, i, _arr1[i], i, _arr2[i]);
    }

    for (int i = 0; i < 8; i++) {
        printf("idx : %d, i : %f\n", i, my_ret[i]);
    }
}

image.png

728x90
반응형
공지사항
최근에 올라온 글
최근에 달린 댓글
Total
Today
Yesterday
링크
«   2025/01   »
1 2 3 4
5 6 7 8 9 10 11
12 13 14 15 16 17 18
19 20 21 22 23 24 25
26 27 28 29 30 31
글 보관함