import sim
import time
import numpy as np
import matplotlib.pyplot as plt


def get_angle(x, y, z):
    h, h_u, h_l = 0.15, 0.35, 0.382
    dyz = np.sqrt(y * y + z * z)
    lyz = np.sqrt(dyz * dyz - h * h)
    gamma_yz = -np.arctan(y / z)
    gamma_h_offset = -np.arctan(h / lyz)
    gamma = gamma_yz - gamma_h_offset
    lxzp = np.sqrt(lyz * lyz + x * x)
    n = (lxzp * lxzp - h_l * h_l - h_u * h_u) / (2 * h_u)
    beta = -np.arccos(n / h_l)
    alfa_xzp = -np.arctan(x / lyz)
    alfa_off = np.arccos((h_u + n) / lxzp)
    alfa = alfa_xzp + alfa_off
    return gamma, alfa, beta


def pose_control(row, pitch, yaw, pos_x, pos_y, pos_z):
    b, l, w, h = 0.4, 0.8, 0.7, 0.732
    R, P, Y = row * np.pi / 180, pitch * np.pi / 180, yaw * np.pi / 180
    pos = np.array((pos_x, pos_y, pos_z))

    rotx = np.array([[1, 0, 0], [0, np.cos(R), -np.sin(R)],
                     [0, np.sin(R), np.cos(R)]])
    roty = np.array([[np.cos(P), 0, -np.sin(P)], [0, 1, 0],
                     [np.sin(P), 0, np.cos(P)]])
    rotz = np.array([[np.cos(Y), -np.sin(Y), 0], [np.sin(Y),
                                                  np.cos(Y), 0], [0, 0, 1]])
    rot_mat = rotx * roty * rotz
    body_struct = np.array([[l / 2, -b / 2, h], [l / 2, b / 2, h],
                            [-l / 2, b / 2, h], [-l / 2, -b / 2, h]]).T

    footpoint_struct = np.array([[l / 2, -w / 2, 0], [l / 2, w / 2, 0],
                                 [-l / 2, w / 2, 0], [-l / 2, -w / 2, 0]]).T
    leg_pose = np.zeros((3, 4))
    for i in range(4):
        leg_pose[:,
                 i] = pos + rot_mat @ body_struct[:, i] - footpoint_struct[:,
                                                                           i]
    return leg_pose[0, 0], -leg_pose[1, 0], -leg_pose[2, 0], leg_pose[
        0, 1], leg_pose[1, 1], -leg_pose[2, 1], leg_pose[0, 2], leg_pose[
            1,
            2], -leg_pose[2, 2], leg_pose[0,
                                          3], -leg_pose[1, 3], -leg_pose[2, 3]


# test code
# row, pitch, yaw, pos_x, pos_y, pos_z = 0, 0, 0, 0, 0, -0.2
# pose_control(row, pitch, yaw, pos_x, pos_y, pos_z)


def trajectory():
    pass
    xs, xf, zs, h = -0.1, 0.1, -0.582, 0.1
    trajectory_x, trajectory_y = [], []
    for i in np.arange(0, 1, 0.01):
        pass
        sigma = 2 * np.pi * i
        xep = (xf - xs) * ((sigma - np.sin(sigma)) / (2 * np.pi)) + xs
        zep = h * (1 - np.cos(sigma)) / 2 + zs
        trajectory_x.append(xep)
        trajectory_y.append(zep)
    return trajectory_x, trajectory_y


def get_gait(t, T, gait_state=1):
    if gait_state == 2:
        Ts, xs, xf, zs, h = T / 4, -0.1, 0.1, -0.482, 0.15
        if t <= Ts:
            sigma = 2 * np.pi * t / Ts
            x = (xf - xs) * ((sigma - np.sin(sigma)) / (2 * np.pi)) + xs
            z = h * (1 - np.cos(sigma)) / 2 + zs
        else:
            x = (xs - xf) / (T - Ts) * (t - Ts) + xf
            z = -0.482
        return x, z
    elif gait_state == 1:
        Ts, xs, xf, zs, h = T / 2, -0.1, 0.1, -0.482, 0.1
        if t <= Ts:
            sigma = 2 * np.pi * t / Ts
            x = (xf - xs) * ((sigma - np.sin(sigma)) / (2 * np.pi)) + xs
            z = h * (1 - np.cos(sigma)) / 2 + zs
        else:
            x = (xs - xf) / (T - Ts) * (t - Ts) + xf
            z = -0.482
        return x, z


# trajectory_x, trajectory_y = trajectory()
# x = np.array([i for i in np.arange(0, 1, 0.01)])
# plt.plot(x, trajectory_x)
# plt.show()

print('Program started')
sim.simxFinish(-1)  # just in case, close all opened connections

clientID = sim.simxStart('127.0.0.1', 19999, True, True, 5000,
                         5)  # Connect to CoppeliaSim
if clientID != 1:
    print("Connected to remote API server")
    # Now try to retrieve data in a blocking fashion (i.e. a service call):
    sim.simxStartSimulation(clientID, sim.simx_opmode_oneshot)
    [rec, rb_rot_3] = sim.simxGetObjectHandle(clientID, 'rb_rot_3',
                                              sim.simx_opmode_blocking)
    [rec, rf_rot_3] = sim.simxGetObjectHandle(clientID, 'rf_rot_3',
                                              sim.simx_opmode_blocking)
    [rec, rb_rot_2] = sim.simxGetObjectHandle(clientID, 'rb_rot_2',
                                              sim.simx_opmode_blocking)
    [rec, rf_rot_2] = sim.simxGetObjectHandle(clientID, 'rf_rot_2',
                                              sim.simx_opmode_blocking)
    [rec, rb_rot_1] = sim.simxGetObjectHandle(clientID, 'rb_rot_1',
                                              sim.simx_opmode_blocking)
    [rec, rf_rot_1] = sim.simxGetObjectHandle(clientID, 'rf_rot_1',
                                              sim.simx_opmode_blocking)
    [rec, lb_rot_3] = sim.simxGetObjectHandle(clientID, 'lb_rot_3',
                                              sim.simx_opmode_blocking)
    [rec, lf_rot_3] = sim.simxGetObjectHandle(clientID, 'lf_rot_3',
                                              sim.simx_opmode_blocking)
    [rec, lb_rot_2] = sim.simxGetObjectHandle(clientID, 'lb_rot_2',
                                              sim.simx_opmode_blocking)
    [rec, lf_rot_2] = sim.simxGetObjectHandle(clientID, 'lf_rot_2',
                                              sim.simx_opmode_blocking)
    [rec, lb_rot_1] = sim.simxGetObjectHandle(clientID, 'lb_rot_1',
                                              sim.simx_opmode_blocking)
    [rec, lf_rot_1] = sim.simxGetObjectHandle(clientID, 'lf_rot_1',
                                              sim.simx_opmode_blocking)
    rb_rot_1_force, rb_rot_2_force, rb_rot_3_force = 500, 500, 500
    rf_rot_1_force, rf_rot_2_force, rf_rot_3_force = 500, 500, 500
    lb_rot_1_force, lb_rot_2_force, lb_rot_3_force = 500, 500, 500
    lf_rot_1_force, lf_rot_2_force, lf_rot_3_force = 500, 500, 500

    rec = sim.simxSetJointForce(clientID, rb_rot_3, rb_rot_3_force,
                                sim.simx_opmode_blocking)
    rec = sim.simxSetJointForce(clientID, rf_rot_3, rf_rot_3_force,
                                sim.simx_opmode_blocking)
    rec = sim.simxSetJointForce(clientID, rb_rot_2, rb_rot_2_force,
                                sim.simx_opmode_blocking)
    rec = sim.simxSetJointForce(clientID, rf_rot_2, rf_rot_2_force,
                                sim.simx_opmode_blocking)
    rec = sim.simxSetJointForce(clientID, rb_rot_1, rb_rot_1_force,
                                sim.simx_opmode_blocking)
    rec = sim.simxSetJointForce(clientID, rf_rot_1, rf_rot_1_force,
                                sim.simx_opmode_blocking)
    rec = sim.simxSetJointForce(clientID, lb_rot_3, lb_rot_3_force,
                                sim.simx_opmode_blocking)
    rec = sim.simxSetJointForce(clientID, lf_rot_3, lf_rot_3_force,
                                sim.simx_opmode_blocking)
    rec = sim.simxSetJointForce(clientID, lb_rot_2, lb_rot_2_force,
                                sim.simx_opmode_blocking)
    rec = sim.simxSetJointForce(clientID, lf_rot_2, lf_rot_2_force,
                                sim.simx_opmode_blocking)
    rec = sim.simxSetJointForce(clientID, lb_rot_1, lb_rot_1_force,
                                sim.simx_opmode_blocking)
    rec = sim.simxSetJointForce(clientID, lf_rot_1, lf_rot_1_force,
                                sim.simx_opmode_blocking)
    # time.sleep(1)
    # row, pitch, yaw, pos_x, pos_y, pos_z = 20, 0, 0, 0, 0, -0.3
    # rb_rot_1_pos, rb_rot_2_pos, rb_rot_3_pos = 0, 0, 0
    # rf_rot_1_pos, rf_rot_2_pos, rf_rot_3_pos = 0, 0, 0
    # lb_rot_1_pos, lb_rot_2_pos, lb_rot_3_pos = 0, 0, 0
    # lf_rot_1_pos, lf_rot_2_pos, lf_rot_3_pos = 0, 0, 0

    # while 1:
    #     [
    #         lb_x, lb_y, lb_z, rb_x, rb_y, rb_z, rf_x, rf_y, rf_z, lf_x, lf_y,
    #         lf_z
    #     ] = pose_control(row, pitch, yaw, pos_x, pos_y, pos_z)

    #     [lb_rot_1_pos, lb_rot_2_pos,
    #      lb_rot_3_pos] = get_angle(lb_x, lb_y, lb_z)
    #     [lf_rot_1_pos, lf_rot_2_pos,
    #      lf_rot_3_pos] = get_angle(lf_x, lf_y, lf_z)
    #     [rb_rot_1_pos, rb_rot_2_pos,
    #      rb_rot_3_pos] = get_angle(rb_x, rb_y, rb_z)
    #     [rf_rot_1_pos, rf_rot_2_pos,
    #      rf_rot_3_pos] = get_angle(rf_x, rf_y, rf_z)

    #     rec = sim.simxSetJointTargetPosition(clientID, lb_rot_1, -lb_rot_1_pos,
    #                                          sim.simx_opmode_oneshot)
    #     rec = sim.simxSetJointTargetPosition(clientID, lb_rot_2, lb_rot_2_pos,
    #                                          sim.simx_opmode_oneshot)
    #     rec = sim.simxSetJointTargetPosition(clientID, lb_rot_3, lb_rot_3_pos,
    #                                          sim.simx_opmode_oneshot)
    #     rec = sim.simxSetJointTargetPosition(clientID, rf_rot_1, rf_rot_1_pos,
    #                                          sim.simx_opmode_oneshot)
    #     rec = sim.simxSetJointTargetPosition(clientID, rf_rot_2, rf_rot_2_pos,
    #                                          sim.simx_opmode_oneshot)
    #     rec = sim.simxSetJointTargetPosition(clientID, rf_rot_3, rf_rot_3_pos,
    #                                          sim.simx_opmode_oneshot)
    #     rec = sim.simxSetJointTargetPosition(clientID, rb_rot_1, -rb_rot_1_pos,
    #                                          sim.simx_opmode_oneshot)
    #     rec = sim.simxSetJointTargetPosition(clientID, rb_rot_2, rb_rot_2_pos,
    #                                          sim.simx_opmode_oneshot)
    #     rec = sim.simxSetJointTargetPosition(clientID, rb_rot_3, rb_rot_3_pos,
    #                                          sim.simx_opmode_oneshot)
    #     rec = sim.simxSetJointTargetPosition(clientID, lf_rot_1, lf_rot_1_pos,
    #                                          sim.simx_opmode_oneshot)
    #     rec = sim.simxSetJointTargetPosition(clientID, lf_rot_2, lf_rot_2_pos,
    #                                          sim.simx_opmode_oneshot)
    #     rec = sim.simxSetJointTargetPosition(clientID, lf_rot_3, lf_rot_3_pos,
    #                                          sim.simx_opmode_oneshot)
    startTime = time.time()
    gait_state = 1
    while time.time() - startTime < 100:

        if time.time() - startTime < 5:

            if gait_state == 2:
                lb_x, rb_x, lf_x, rf_x = -0.1, 0.1, 0.1, -0.1
                lb_z, rf_z, lf_z, rb_z = -0.482, -0.482, -0.482, -0.482
            elif gait_state == 1:
                lb_x, rb_x, lf_x, rf_x = -0.1, -0.1, 0.1, 0.1
                lb_z, rf_z, lf_z, rb_z = -0.482, -0.482, -0.482, -0.482
            [rec, vrep_time] = sim.simxGetFloatSignal(clientID, 'time',
                                                      sim.simx_opmode_oneshot)

        [rec, vrep_realtime] = sim.simxGetFloatSignal(clientID, 'time',
                                                      sim.simx_opmode_oneshot)
        if gait_state == 1:
            T = 0.4
            time1 = vrep_realtime - vrep_time
            time2 = vrep_realtime - vrep_time + 0.2
            T1, T2 = time1 % T, time2 % T
            [lb_x, lb_z] = get_gait(T1, T)
            [rf_x, rf_z] = get_gait(T1, T)
            [rb_x, rb_z] = get_gait(T2, T)
            [lf_x, lf_z] = get_gait(T2, T)
        elif gait_state == 2:
            T = 1
            time1 = vrep_realtime - vrep_time
            time2 = vrep_realtime - vrep_time + 0.25
            time3 = vrep_realtime - vrep_time + 0.5
            time4 = vrep_realtime - vrep_time + 0.75
            T1, T2, T3, T4 = time1 % T, time2 % T, time3 % T, time4 % T
            [lb_x, lb_z] = get_gait(T1, T, 2)
            [rf_x, rf_z] = get_gait(T2, T, 2)
            [rb_x, rb_z] = get_gait(T3, T, 2)
            [lf_x, lf_z] = get_gait(T4, T, 2)
        [lb_rot_1_pos, lb_rot_2_pos,
         lb_rot_3_pos] = get_angle(lb_x, -0.15, lb_z)
        [lf_rot_1_pos, lf_rot_2_pos,
         lf_rot_3_pos] = get_angle(lf_x, -0.15, lf_z)
        [rb_rot_1_pos, rb_rot_2_pos,
         rb_rot_3_pos] = get_angle(rb_x, -0.15, rb_z)
        [rf_rot_1_pos, rf_rot_2_pos,
         rf_rot_3_pos] = get_angle(rf_x, -0.15, rf_z)
        rec = sim.simxSetJointTargetPosition(clientID, lb_rot_1, -lb_rot_1_pos,
                                             sim.simx_opmode_oneshot)
        rec = sim.simxSetJointTargetPosition(clientID, lb_rot_2, lb_rot_2_pos,
                                             sim.simx_opmode_oneshot)
        rec = sim.simxSetJointTargetPosition(clientID, lb_rot_3, lb_rot_3_pos,
                                             sim.simx_opmode_oneshot)
        rec = sim.simxSetJointTargetPosition(clientID, rf_rot_1, rf_rot_1_pos,
                                             sim.simx_opmode_oneshot)
        rec = sim.simxSetJointTargetPosition(clientID, rf_rot_2, rf_rot_2_pos,
                                             sim.simx_opmode_oneshot)
        rec = sim.simxSetJointTargetPosition(clientID, rf_rot_3, rf_rot_3_pos,
                                             sim.simx_opmode_oneshot)
        rec = sim.simxSetJointTargetPosition(clientID, rb_rot_1, -rb_rot_1_pos,
                                             sim.simx_opmode_oneshot)
        rec = sim.simxSetJointTargetPosition(clientID, rb_rot_2, rb_rot_2_pos,
                                             sim.simx_opmode_oneshot)
        rec = sim.simxSetJointTargetPosition(clientID, rb_rot_3, rb_rot_3_pos,
                                             sim.simx_opmode_oneshot)
        rec = sim.simxSetJointTargetPosition(clientID, lf_rot_1, lf_rot_1_pos,
                                             sim.simx_opmode_oneshot)
        rec = sim.simxSetJointTargetPosition(clientID, lf_rot_2, lf_rot_2_pos,
                                             sim.simx_opmode_oneshot)
        rec = sim.simxSetJointTargetPosition(clientID, lf_rot_3, lf_rot_3_pos,
                                             sim.simx_opmode_oneshot)

    sim.simxStopSimulation(clientID, sim.simx_opmode_blocking)
    sim.simxFinish(clientID)
else:
    print('Failed connecting to remote API server')

print('Program ended')