#ifndef TWN_VEC_H
#define TWN_VEC_H

/* vector ops for C users */

#include "twn_types.h"

#include <stdint.h>
#include <math.h>


static inline Vec2 vec2_add(Vec2 a, Vec2 b) {
    return (Vec2) { a.x + b.x, a.y + b.y };
}

static inline Vec2 vec2_sub(Vec2 a, Vec2 b) {
    return (Vec2) { a.x - b.x, a.y - b.y };
}

static inline Vec2 vec2_div(Vec2 a, Vec2 b) {
    return (Vec2) { a.x / b.x, a.y / b.y };
}

static inline Vec2 vec2_mul(Vec2 a, Vec2 b) {
    return (Vec2) { a.x * b.x, a.y * b.y };
}

static inline Vec2 vec2_scale(Vec2 a, float s) {
    return (Vec2) { a.x * s, a.y * s };
}

static inline float vec2_dot(Vec2 a, Vec2 b) {
    return a.x * b.x + a.y * b.y;
}

static inline float vec2_length(Vec2 a) {
    return sqrtf(a.x * a.x + a.y * a.y);
}

static inline Vec2 vec2_norm(Vec2 a) {
    const float n = sqrtf(vec2_dot(a, a));
    /* TODO: do we need truncating over epsilon as cglm does? */
    return vec2_scale(a, 1.0f / n);
}

static inline Vec3 vec3_add(Vec3 a, Vec3 b) {
    return (Vec3) { a.x + b.x, a.y + b.y, a.z + b.z };
}

static inline Vec3 vec3_sub(Vec3 a, Vec3 b) {
    return (Vec3) { a.x - b.x, a.y - b.y, a.z - b.z };
}

static inline Vec3 vec3_div(Vec3 a, Vec3 b) {
    return (Vec3) { a.x / b.x, a.y / b.y, a.z / b.z };
}

static inline Vec3 vec3_mul(Vec3 a, Vec3 b) {
    return (Vec3) { a.x * b.x, a.y * b.y, a.z * b.z };
}

static inline Vec3 vec3_scale(Vec3 a, float s) {
    return (Vec3) { a.x * s, a.y * s, a.z * s };
}

static inline float vec3_dot(Vec3 a, Vec3 b) {
    return a.x * b.x + a.y * b.y + a.z * b.z;
}

static inline float vec3_length(Vec3 a) {
    return sqrtf(a.x * a.x + a.y * a.y + a.z * a.z);
}

static inline Vec3 vec3_cross(Vec3 a, Vec3 b) {
    return (Vec3) {
        a.y * b.z - a.z * b.y,
        a.z * b.x - a.x * b.z,
        a.x * b.y - a.y * b.x,
    };
}

/* TODO: fast_sqrt version? */
static inline Vec3 vec3_norm(Vec3 a) {
    const float n = sqrtf(vec3_dot(a, a));
    /* TODO: do we need truncating over epsilon as cglm does? */
    return vec3_scale(a, 1.0f / n);
}

static inline float vec3_angle(Vec3 a, Vec3 b) {
    return acosf(vec3_dot(vec3_norm(a), vec3_norm(b)));
}

static inline Vec3 vec3_rotate(Vec3 v, float angle, Vec3 axis) {
    /* from cglm */
    Vec3 v1, v2, k;
    float c, s;

    c = cosf(angle);
    s = sinf(angle);

    k = vec3_norm(axis);

    /* Right Hand, Rodrigues' rotation formula:
        v = v*cos(t) + (kxv)sin(t) + k*(k.v)(1 - cos(t))
    */
    v1 = vec3_scale(v, c);

    v2 = vec3_cross(k, v);
    v2 = vec3_scale(v2, s);

    v1 = vec3_add(v1, v2);

    v2 = vec3_scale(k, vec3_dot(k, v) * (1.0f - c));
    v = vec3_add(v1, v2);

    return v;
}


/* TODO: remove. */
#define m_vec_add(p_any_vec0, p_any_vec1) (_Generic((p_any_vec0),   \
                                    Vec2:   vec2_add,               \
                                    Vec3:   vec3_add                \
                                )(p_any_vec0, p_any_vec1))

#define m_vec_sub(p_any_vec0, p_any_vec1) (_Generic((p_any_vec0),   \
                                    Vec2:   vec2_sub,               \
                                    Vec3:   vec3_sub                \
                                )(p_any_vec0, p_any_vec1))

#define m_vec_div(p_any_vec0, p_any_vec1) (_Generic((p_any_vec0),   \
                                    Vec2:   vec2_div,               \
                                    Vec3:   vec3_div                \
                                )(p_any_vec0, p_any_vec1))

#define m_vec_mul(p_any_vec0, p_any_vec1) (_Generic((p_any_vec0),   \
                                    Vec2:   vec2_mul,               \
                                    Vec3:   vec3_mul                \
                                )(p_any_vec0, p_any_vec1))

#define m_vec_scale(p_any_vec, p_any_scalar) (_Generic((p_any_vec),    \
                                    Vec2:   vec2_scale,                \
                                    Vec3:   vec3_scale                 \
                                )(p_any_vec, p_any_scalar))

#define m_vec_dot(p_any_vec0, p_any_vec1) (_Generic((p_any_vec0),   \
                                    Vec3:   vec3_dot                \
                                )(p_any_vec0, p_any_vec1))

#define m_vec_cross(p_any_vec0, p_any_vec1) (_Generic((p_any_vec0),      \
                                    Vec3:   vec3_cross                   \
                                )(p_any_vec0, p_any_vec1))

#define m_vec_norm(p_any_vec) (_Generic((p_any_vec),            \
                                    Vec3:   vec3_norm           \
                                )(p_any_vec))

#endif