// Source Code
/*
░██████╗░█████╗░███████╗███████╗██████╗░██╗░░░██╗
██╔════╝██╔══██╗██╔════╝██╔════╝██╔══██╗╚██╗░██╔╝
╚█████╗░███████║█████╗░░█████╗░░██║░░██║░╚████╔╝░
░╚═══██╗██╔══██║██╔══╝░░██╔══╝░░██║░░██║░░╚██╔╝░░
██████╔╝██║░░██║███████╗███████╗██████╔╝░░░██║░░░
╚═════╝░╚═╝░░╚═╝╚══════╝╚══════╝╚═════╝░░░░╚═╝░░░

██╗░█████╗░██████╗░░█████╗░  ░█████╗░░██████╗░██████╗██╗██╗░░░██╗████████╗
██║██╔══██╗██╔══██╗██╔══██╗  ██╔══██╗██╔════╝██╔════╝██║██║░░░██║╚══██╔══╝
██║██║░░╚═╝██████╔╝██║░░╚═╝  ███████║╚█████╗░╚█████╗░██║██║░░░██║░░░██║░░░
██║██║░░██╗██╔═══╝░██║░░██╗  ██╔══██║░╚═══██╗░╚═══██╗██║██║░░░██║░░░██║░░░
██║╚█████╔╝██║░░░░░╚█████╔╝  ██║░░██║██████╔╝██████╔╝██║╚██████╔╝░░░██║░░░
╚═╝░╚════╝░╚═╝░░░░░░╚════╝░  ╚═╝░░╚═╝╚═════╝░╚═════╝░╚═╝░╚═════╝░░░░╚═╝░░░
*/
#define _CRT_SECURE_NO_WARNINGS
#include<bits/stdc++.h>
#include<cctype>
#include<climits>
#include<string>
#include<unordered_map>
// ---- Competitive-programming shorthand macros ----
// Reads a test count t from cin, then runs the statement written after the
// macro t times (presumably used as `testcase solve();` -- see main()).
#define testcase int t; cin>>t; while(t--)
// pi evaluated as acos(-1); eps is a small tolerance for floating-point comparisons.
#define pi acos(-1)
#define eps 1e-9
// Switches cout to fixed notation with n digits after the decimal point.
#define fix(n) cout <<fixed<<setprecision(n)
// Prints a newline. NOTE(review): the expansion already ends in ';', and any
// identifier spelled `line` anywhere below is rewritten by this macro.
#define line cout << '\n';
// Container size as int (avoids signed/unsigned comparison warnings).
#define sz(s)	(int)(s.size())
// Full forward / reverse iterator ranges, for passing to STL algorithms.
#define all(v) v.begin(),v.end()
#define allr(v) v.rbegin(),v.rend()
// Byte-fills a fixed-size array with memset. sizeof(arr) is the whole array
// only for a real array (not a pointer), and a byte fill only yields whole
// elements equal to val for byte patterns such as 0 or -1.
#define dpp(arr,val) memset(arr,val,sizeof(arr))
// Type / name shorthands.
// NOTE(review): S and F rewrite every identifier named S or F in this file.
#define ull unsigned long long
#define ld long double
#define pq priority_queue
#define mp make_pair
#define S second
#define F first

using namespace std;
typedef long long   ll;

void file()
{
    // Optional file redirection hook. Every redirect is commented out, so this
    // function currently does nothing. Uncomment the pair matching the build
    // to read input from, and write output to, files instead of the console.
#ifdef ONLINE_JUDGE
    // judge build
    //freopen("inc.in", "r", stdin);
    //freopen("out.txt", "w", stdout);
#else
    // local build
    //freopen("take.txt", "r", stdin);
    //freopen("print.txt", "w", stdout);
#endif
}

void fast()
{
    // Decouple the C++ streams from C stdio, and stop cin from flushing cout
    // before every extraction. Both changes speed up bulk I/O; afterwards,
    // do not mix printf/scanf with cout/cin.
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr);
}

/*
 * Returns bit `index` (0 = least significant) of num's two's-complement
 * representation.
 *
 * Shifting by a negative amount, or by at least the bit width of the promoted
 * left operand, is undefined behaviour. The former `num >> index` was therefore
 * UB for index outside [0, width-1] (e.g. [0, 31] for a 32-bit int).
 * Out-of-range indices are now defined:
 *   - index < 0            -> false (there is no such bit)
 *   - index >= int width   -> the sign bit, i.e. infinite sign extension:
 *                             true iff num is negative
 * The shift is done on the unsigned image of num, so negative inputs do not
 * depend on how right-shifting a negative int behaves.
 */
bool getbit(const int& num, const int& index)
{
    constexpr int kBits = std::numeric_limits<unsigned>::digits;
    if (index < 0) return false;
    if (index >= kBits) return num < 0;
    return ((static_cast<unsigned>(num) >> index) & 1u) != 0u;
}

// Greatest common divisor by iterative Euclid.
// The result is |a| once b reaches 0, so it is non-negative (gcd(0, 0) == 0),
// except that |LLONG_MIN| is not representable in ll.
ll gcd(ll a, ll b)
{
    while (b != 0) {
        const ll rem = a % b;
        a = b;
        b = rem;
    }
    return a < 0 ? -a : a;
}

/*
 * Least common multiple of a and b.
 *
 * Returns 0 when either argument is 0 (the usual convention, as in std::lcm).
 * Previously lcm(0, 0) divided by gcd(0, 0) == 0.
 * Dividing before multiplying (a / gcd * b) keeps the intermediate no larger
 * than the result. The product can still overflow ll if the true LCM does not fit.
 * gcd() returns a non-negative value, so the sign here is negative when exactly
 * one input is negative (unlike std::lcm, which returns the absolute value).
 */
ll lcm(ll a, ll b)
{
    if (a == 0 || b == 0) return 0;
    return a / gcd(a, b) * b;
}


/*
 * Reads n, k and then n integers a[0..n-1]. Prints
 *
 *     max over 0 <= j < k of ( a[j] + (k - 1 - j) )
 *
 * where a is sorted ascending, so only the k smallest values matter.
 * Equivalently, walk those k values in ascending order: each one takes the
 * next time slot, but never a slot earlier than its own value:
 *     t[0] = a[0],   t[j] = max(t[j-1] + 1, a[j]),   answer = t[k-1].
 *
 * Fix: the former backward scan added the overshoot (a[i] - slot) to the
 * answer but kept lowering the slot from the old baseline. Every earlier value
 * was then compared against a slot that ignored the shift already paid for,
 * so the same delay was counted again. For "3 3 / 1 1 1" it printed 4
 * instead of 3, and for "4 4 / 4 4 4 4" it printed 10 instead of 7.
 * NOTE(review): reading this as "tree i is ready at time a[i], one tree is cut
 * per time unit" is inferred from the algorithm's shape; confirm it against
 * the problem statement.
 */
void solve()
{
    int n = 0, k = 0;
    std::cin >> n >> k;
    std::vector<int> a(n > 0 ? n : 0);
    for (auto& x : a) std::cin >> x;

    // Constraints presumably guarantee 1 <= k <= n. Clamp anyway so a[k - 1]
    // is never read out of bounds (the old v[k - 1] was UB for k <= 0 or k > n).
    k = std::min(k, static_cast<int>(a.size()));
    if (k <= 0) {
        std::cout << 0 << '\n';
        return;
    }

    std::sort(a.begin(), a.end());

    // Accumulate in long long: a[j] + (k - 1 - j) may exceed INT_MAX.
    long long finish = a[0];
    for (int j = 1; j < k; ++j)
        finish = std::max(finish + 1, static_cast<long long>(a[j]));

    std::cout << finish << '\n';
}

int main()
{
    fast();      // unsynchronised, untied C++ streams for faster I/O
    // file();   // enable to redirect stdin/stdout to files while debugging
    // testcase  // enable for multi-test input: reads t, then runs solve() t times
    solve();
    return 0;
}
// Submission record scraped from the judge page:
// Problem: Cutting Trees Saeedy | Compiler: GNU G++17 | Time: 23 ms | Memory: 632 KB | Verdict: Wrong Answer